From 3c5e464fc238d7912555d4b520c0db771437ad12 Mon Sep 17 00:00:00 2001 From: anastasios Date: Thu, 30 Apr 2026 06:30:53 +0000 Subject: [PATCH 01/16] (examples) triangular inverse --- .../kernels/aic/kernel_gemm_tile.cpp | 150 ++++++++++++++++++ .../kernels/aiv/kernel_tile_add.cpp | 107 +++++++++++++ .../kernels/orchestration/bgemm_orch.cpp | 121 ++++++++++++++ .../test_triangular_inverse.py | 99 ++++++++++++ 4 files changed, 477 insertions(+) create mode 100644 examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_gemm_tile.cpp create mode 100644 examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aiv/kernel_tile_add.cpp create mode 100644 examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/orchestration/bgemm_orch.cpp create mode 100644 examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py diff --git a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_gemm_tile.cpp b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_gemm_tile.cpp new file mode 100644 index 000000000..1f331d6e0 --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_gemm_tile.cpp @@ -0,0 +1,150 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ----------------------------------------------------------------------------------------------------------- + */ +/** + * Tile-based Matrix Multiplication Kernel (Cube Core) + * + * Computes: output = input_a @ input_b (tile_size x tile_size tile matmul) + * Uses TMATMUL instruction + * + * Tile size is determined by golden.py configuration and passed through + * tensor shapes from orchestration. + * + * Args (Tensor*): + * args[0] = input_a (INPUT) + * args[1] = input_b (INPUT) + * args[2] = output (OUTPUT) + * args[3] = config (INPUT) - int64_t[4]: [tile_size, grid_k, num_groups, incore_loop] + */ + +#include +#include +#include +#include + +#include "tensor.h" + +using namespace pto; + +#include "pipe_sync.h" + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +template +AICORE constexpr inline T CeilAlign(T num_1, T num_2) { + if (num_2 == 0) { + return 0; + } + return (num_1 + num_2 - 1) / num_2 * num_2; +} + +template +static __aicore__ void gemm_tile_impl(__gm__ float *input_a, __gm__ float *input_b, __gm__ float *output) { + constexpr int blockAlign = C0_SIZE_BYTE / sizeof(float); + constexpr int M = CeilAlign(TILE, 16); + constexpr int K = CeilAlign(TILE, blockAlign); + constexpr int N = CeilAlign(TILE, blockAlign); + + using GlobalDataA = + GlobalTensor, Stride<1 * TILE * TILE, 1 * TILE * TILE, TILE * TILE, TILE, 1>>; + using GlobalDataB = + GlobalTensor, Stride<1 * TILE * TILE, 1 * TILE * TILE, TILE * TILE, TILE, 1>>; + using GlobalDataC = + GlobalTensor, Stride<1 * TILE * TILE, 1 * TILE * TILE, TILE * TILE, TILE, 1>>; + + GlobalDataA src0Global(input_a); + GlobalDataB src1Global(input_b); + GlobalDataC dstGlobal(output); + + using TileMatA = Tile; + using TileMatB = Tile; + + using LeftTile = TileLeft; + using RightTile = TileRight; + using AccTile = TileAcc; + + TileMatA aMatTile; + TileMatB bMatTile; + TASSIGN(aMatTile, 0x0); + TASSIGN(bMatTile, 0x20000); + + LeftTile aTile; + RightTile bTile; + 
AccTile cTile; + TASSIGN(aTile, 0x0); + TASSIGN(bTile, 0x0); + TASSIGN(cTile, 0x0); + + TLOAD(aMatTile, src0Global); + TLOAD(bMatTile, src1Global); + + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + + TMOV(aTile, aMatTile); + TMOV(bTile, bMatTile); + + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + + TMATMUL(cTile, aTile, bTile); + + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + + TSTORE(dstGlobal, cTile); + + pipe_sync(); +} + +extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) { + __gm__ Tensor *input_a = reinterpret_cast<__gm__ Tensor *>(args[0]); + __gm__ Tensor *input_b = reinterpret_cast<__gm__ Tensor *>(args[1]); + __gm__ Tensor *output = reinterpret_cast<__gm__ Tensor *>(args[2]); + __gm__ Tensor *config = reinterpret_cast<__gm__ Tensor *>(args[3]); + + __gm__ int64_t *cfg = reinterpret_cast<__gm__ int64_t *>(config->buffer.addr); + uint64_t tile_size = static_cast(cfg[0]); + uint64_t tile_elems = tile_size * tile_size; + int num_tiles = static_cast(cfg[3]); + + __gm__ float *base_a = reinterpret_cast<__gm__ float *>(input_a->buffer.addr) + input_a->start_offset; + __gm__ float *base_b = reinterpret_cast<__gm__ float *>(input_b->buffer.addr) + input_b->start_offset; + __gm__ float *base_c = reinterpret_cast<__gm__ float *>(output->buffer.addr) + output->start_offset; + + for (int tile_idx = 0; tile_idx < num_tiles; tile_idx++) { + __gm__ float *a_ptr = base_a + (tile_idx * tile_elems); + __gm__ float *b_ptr = base_b + (tile_idx * tile_elems); + __gm__ float *c_ptr = base_c + (tile_idx * tile_elems); + + switch (tile_size) { + case 16: + gemm_tile_impl<16>(a_ptr, b_ptr, c_ptr); + break; + case 32: + gemm_tile_impl<32>(a_ptr, b_ptr, c_ptr); + break; + case 64: + gemm_tile_impl<64>(a_ptr, b_ptr, c_ptr); + break; + case 128: + gemm_tile_impl<128>(a_ptr, b_ptr, c_ptr); + break; + default: + break; + } + } +} diff --git 
a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aiv/kernel_tile_add.cpp b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aiv/kernel_tile_add.cpp new file mode 100644 index 000000000..c80e88244 --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aiv/kernel_tile_add.cpp @@ -0,0 +1,107 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ----------------------------------------------------------------------------------------------------------- + */ +/** + * Tile-based Element-wise Addition Kernel (Vector Core) - INOUT Pattern + * + * Computes: C_tile = C_tile + P (tile_size x tile_size tile accumulation) + * Uses TADD instruction + * + * Tile size is determined by golden.py configuration and passed through + * tensor shapes from orchestration. 
+ * + * Args (Tensor*): + * args[0] = C_tile (INOUT: read + write accumulator) + * args[1] = P (INPUT: matmul result to accumulate) + * args[2] = config (INPUT) - int64_t[4]: [tile_size, grid_k, num_groups, incore_loop] + */ + +#include +#include +#include + +#include "tensor.h" + +using namespace pto; + +#include "pipe_sync.h" + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +template +static __aicore__ void tile_add_impl(__gm__ float *c_ptr, __gm__ float *p_ptr) { + using DynShapeDim5 = Shape<1, 1, 1, TILE, TILE>; + using DynStridDim5 = Stride<1, 1, 1, TILE, 1>; + using GlobalData = GlobalTensor; + using TileData = Tile; + + TileData cTile(TILE, TILE); + TileData pTile(TILE, TILE); + TileData outTile(TILE, TILE); + TASSIGN(cTile, 0x0); + TASSIGN(pTile, 0x10000); + TASSIGN(outTile, 0x20000); + + GlobalData cGlobal(c_ptr); + GlobalData pGlobal(p_ptr); + GlobalData outGlobal(c_ptr); // write back to same C location + + TLOAD(cTile, cGlobal); + TLOAD(pTile, pGlobal); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TADD(outTile, cTile, pTile); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(outGlobal, outTile); + pipe_sync(); +} + +extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) { + __gm__ Tensor *c_tensor = reinterpret_cast<__gm__ Tensor *>(args[0]); + __gm__ Tensor *p_tensor = reinterpret_cast<__gm__ Tensor *>(args[1]); + __gm__ Tensor *config = reinterpret_cast<__gm__ Tensor *>(args[2]); + + __gm__ int64_t *cfg = reinterpret_cast<__gm__ int64_t *>(config->buffer.addr); + uint64_t tile_size = static_cast(cfg[0]); + uint64_t tile_elems = tile_size * tile_size; + int num_tiles = static_cast(cfg[3]); + + __gm__ float *base_c = reinterpret_cast<__gm__ float *>(c_tensor->buffer.addr) + c_tensor->start_offset; + __gm__ float *base_p = reinterpret_cast<__gm__ float *>(p_tensor->buffer.addr) + p_tensor->start_offset; + + 
for (int tile_idx = 0; tile_idx < num_tiles; tile_idx++) { + __gm__ float *c_ptr = base_c + (tile_idx * tile_elems); + __gm__ float *p_ptr = base_p + (tile_idx * tile_elems); + + switch (tile_size) { + case 16: + tile_add_impl<16>(c_ptr, p_ptr); + break; + case 32: + tile_add_impl<32>(c_ptr, p_ptr); + break; + case 64: + tile_add_impl<64>(c_ptr, p_ptr); + break; + case 128: + tile_add_impl<128>(c_ptr, p_ptr); + break; + default: + break; + } + } +} diff --git a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/orchestration/bgemm_orch.cpp b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/orchestration/bgemm_orch.cpp new file mode 100644 index 000000000..116487942 --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/orchestration/bgemm_orch.cpp @@ -0,0 +1,121 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ----------------------------------------------------------------------------------------------------------- + */ +/** + * BGEMM Orchestration Function (tensormap_and_ringbuffer Runtime) + * + * Builds the task graph for tiled matrix multiplication: C = A @ B + * + * Configuration read from scalar args (set in golden.py): + * - tile_size: tile dimension (tile_size x tile_size per tile) + * - grid_k: number of K-dimension partitions + * - num_groups: number of independent groups (= matmul_add_task_num / grid_k) + * - incore_loop: number of tiles per group + * + * Memory layout (tile-first, flattened): + * A: [num_groups, grid_k, incore_loop, tile_size, tile_size] + * B: [num_groups, grid_k, incore_loop, tile_size, tile_size] + * C: [incore_loop * num_groups, tile_size, tile_size] + * + * Arg layout: [A, B, C, config] + */ + +#include +#include + +#include "pto_orchestration_api.h" // NOLINT(build/include_subdir) + +#define FUNC_GEMM_TILE 0 +#define FUNC_TILE_ADD 1 + +extern "C" { + +__attribute__((visibility("default"))) PTO2OrchestrationConfig +aicpu_orchestration_config(const ChipStorageTaskArgs &orch_args) { + (void)orch_args; // NOLINT(readability/casting) + return PTO2OrchestrationConfig{ + .expected_arg_count = 4, + }; +} + +__attribute__((visibility("default"))) void aicpu_orchestration_entry(const ChipStorageTaskArgs &orch_args) { + // Tensor args + Tensor ext_A = from_tensor_arg(orch_args.tensor(0)); + Tensor ext_B = from_tensor_arg(orch_args.tensor(1)); + Tensor ext_C = from_tensor_arg(orch_args.tensor(2)); + Tensor ext_config = from_tensor_arg(orch_args.tensor(3)); + + // Read config from tensor data: [tile_size, grid_k, num_groups, incore_loop] + int64_t *host_config = orch_args.tensor(3).data_as(); + int tile_size = static_cast(host_config[0]); + int grid_k = static_cast(host_config[1]); + int num_groups = static_cast(host_config[2]); + int incore_loop = static_cast(host_config[3]); + uint64_t tile_elems = static_cast(tile_size) * tile_size; + + int 
grid_m = 1; + int grid_n = 1; + + LOG_INFO_V0( + "[bgemm_orch] tile_size: %d, grid_m: %d, grid_n: %d, grid_k: %d, num_groups: %d, incore_loop: %d", tile_size, + grid_m, grid_n, grid_k, num_groups, incore_loop + ); + + uint32_t tile_shapes[1] = {static_cast(tile_elems)}; + uint64_t group_tile_elems = static_cast(incore_loop) * tile_elems; + uint32_t group_shapes[1] = {static_cast(group_tile_elems)}; + TensorCreateInfo group_ci(group_shapes, 1, DataType::FLOAT32); + + int total_gemm = 0; + int total_add = 0; + + // A/B layout: [num_groups, grid_k, incore_loop, tile_size, tile_size] + // C layout: [incore_loop * num_groups, tile_size, tile_size] + for (int group_idx = 0; group_idx < num_groups; group_idx++) { + PTO2_SCOPE_GUARD(); + + uint32_t c_elem_offset = static_cast(static_cast(group_idx) * group_tile_elems); + uint32_t c_view_offsets[1] = {c_elem_offset}; + Tensor C_view = ext_C.view(group_shapes, c_view_offsets); + + for (int k_idx = 0; k_idx < grid_k; k_idx++) { + // In layout [num_groups, grid_k, incore_loop, tile_size, tile_size], + // offset = (group_idx * grid_k + k_idx) * incore_loop * tile_elems + uint64_t ab_offset = + (static_cast(group_idx) * grid_k + static_cast(k_idx)) * group_tile_elems; + + uint32_t a_view_offsets[1] = {static_cast(ab_offset)}; + Tensor A_view = ext_A.view(group_shapes, a_view_offsets); + uint32_t b_view_offsets[1] = {static_cast(ab_offset)}; + Tensor B_view = ext_B.view(group_shapes, b_view_offsets); + Arg params_gemm; + params_gemm.add_input(A_view); + params_gemm.add_input(B_view); + params_gemm.add_output(group_ci); + params_gemm.add_input(ext_config); + TaskOutputTensors gemm_outs = rt_submit_aic_task(FUNC_GEMM_TILE, params_gemm); + total_gemm++; + + Arg params_add; + params_add.add_inout(C_view); + params_add.add_input(gemm_outs.get_ref(0)); + params_add.add_input(ext_config); + rt_submit_aiv_task(FUNC_TILE_ADD, params_add); + total_add++; + } + } + + LOG_INFO_V0( + "[bgemm_orch] Submitted %d gemm tasks and %d add tasks (%d 
total)", total_gemm, total_add, + total_gemm + total_add + ); +} + +} // extern "C" diff --git a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py new file mode 100644 index 000000000..ccd3323b1 --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python3 +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+# ----------------------------------------------------------------------------------------------------------- +"""Test triangular inverse cube-only method: runtime-configurable C = torch.linalg.triangular_solve(A, torch.eye(A.shape[-1])).""" + +import torch +from simpler.task_interface import ArgDirection as D + +from simpler_setup import SceneTestCase, TaskArgsBuilder, Tensor, scene_test + + +@scene_test(level=2, runtime="tensormap_and_ringbuffer") +class TestTriangularInverse(SceneTestCase): + RTOL = 1e-3 + ATOL = 1e-3 + + CALLABLE = { + "orchestration": { + "source": "kernels/orchestration/bgemm_orch.cpp", + "function_name": "aicpu_orchestration_entry", + "signature": [D.IN, D.IN, D.OUT, D.IN], + }, + "incores": [ + { + "func_id": 0, + "name": "GEMM", + "source": "kernels/aic/kernel_gemm_tile.cpp", + "core_type": "aic", + "signature": [D.IN, D.IN, D.OUT], + }, + { + "func_id": 1, + "name": "ADD", + "source": "kernels/aiv/kernel_tile_add.cpp", + "core_type": "aiv", + "signature": [D.INOUT, D.IN], + }, + ], + } + + CASES = [ + { + "name": "Case0", + "platforms": ["a2a3sim", "a2a3"], + "config": {"aicpu_thread_num": 2, "block_dim": 20}, + "params": {"matmul_add_task_num": 500, "incore_data_size": 128, "incore_loop": 4, "grid_k": 2}, + }, + { + "name": "Case1", + "manual": True, + "platforms": ["a2a3sim", "a2a3"], + "config": {"aicpu_thread_num": 2, "block_dim": 20}, + "params": {"matmul_add_task_num": 64, "incore_data_size": 128, "incore_loop": 4, "grid_k": 2}, + }, + { + "name": "Case2", + "manual": True, + "platforms": ["a2a3sim", "a2a3"], + "config": {"aicpu_thread_num": 2, "block_dim": 20}, + "params": {"matmul_add_task_num": 256, "incore_data_size": 128, "incore_loop": 4, "grid_k": 2}, + }, + ] + + def generate_args(self, params): + tile_size = params["incore_data_size"] + incore_loop = params["incore_loop"] + grid_k = params["grid_k"] + num_groups = params["matmul_add_task_num"] // grid_k + A = torch.randn(num_groups, grid_k, incore_loop, tile_size, 
tile_size, dtype=torch.float32) * 0.01 + B = torch.randn(num_groups, grid_k, incore_loop, tile_size, tile_size, dtype=torch.float32) * 0.01 + C = torch.zeros(incore_loop * num_groups, tile_size, tile_size, dtype=torch.float32) + config = torch.tensor([tile_size, grid_k, num_groups, incore_loop], dtype=torch.int64) + return TaskArgsBuilder( + Tensor("A", A.flatten()), Tensor("B", B.flatten()), Tensor("C", C.flatten()), Tensor("config", config) + ) + + def compute_golden(self, args, params): + tile_size = params["incore_data_size"] + incore_loop = params["incore_loop"] + grid_k = params["grid_k"] + num_groups = params["matmul_add_task_num"] // grid_k + A = args.A.reshape(num_groups, grid_k, incore_loop, tile_size, tile_size) + B = args.B.reshape(num_groups, grid_k, incore_loop, tile_size, tile_size) + C = args.C.reshape(incore_loop * num_groups, tile_size, tile_size) + C[:] = 0.0 + for group in range(num_groups): + for k_idx in range(grid_k): + for i in range(incore_loop): + C[group * incore_loop + i] += torch.matmul(A[group, k_idx, i], B[group, k_idx, i]) + + +if __name__ == "__main__": + SceneTestCase.run_module(__name__) From b819651ad6053b7950f5920ed91ada2af7640509 Mon Sep 17 00:00:00 2001 From: anastasios Date: Thu, 30 Apr 2026 07:49:58 +0000 Subject: [PATCH 02/16] wip --- Makefile | 4 + .../kernels/aic/kernel_gemm_tile.cpp | 150 -------------- .../kernels/aic/kernel_simple_matmul.cpp | 184 ++++++++++++++++++ .../kernels/aiv/kernel_tile_add.cpp | 107 ---------- ...m_orch.cpp => triangular_inverse_orch.cpp} | 49 +---- .../test_triangular_inverse.py | 36 ++-- 6 files changed, 207 insertions(+), 323 deletions(-) create mode 100644 Makefile delete mode 100644 examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_gemm_tile.cpp create mode 100644 examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_simple_matmul.cpp delete mode 100644 
examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aiv/kernel_tile_add.cpp rename examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/orchestration/{bgemm_orch.cpp => triangular_inverse_orch.cpp} (62%) diff --git a/Makefile b/Makefile new file mode 100644 index 000000000..0ed2ee243 --- /dev/null +++ b/Makefile @@ -0,0 +1,4 @@ + + +all: + python examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py -p a2a3sim diff --git a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_gemm_tile.cpp b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_gemm_tile.cpp deleted file mode 100644 index 1f331d6e0..000000000 --- a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_gemm_tile.cpp +++ /dev/null @@ -1,150 +0,0 @@ -/* - * Copyright (c) PyPTO Contributors. - * This program is free software, you can redistribute it and/or modify it under the terms and conditions of - * CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ----------------------------------------------------------------------------------------------------------- - */ -/** - * Tile-based Matrix Multiplication Kernel (Cube Core) - * - * Computes: output = input_a @ input_b (tile_size x tile_size tile matmul) - * Uses TMATMUL instruction - * - * Tile size is determined by golden.py configuration and passed through - * tensor shapes from orchestration. 
- * - * Args (Tensor*): - * args[0] = input_a (INPUT) - * args[1] = input_b (INPUT) - * args[2] = output (OUTPUT) - * args[3] = config (INPUT) - int64_t[4]: [tile_size, grid_k, num_groups, incore_loop] - */ - -#include -#include -#include -#include - -#include "tensor.h" - -using namespace pto; - -#include "pipe_sync.h" - -#ifndef __gm__ -#define __gm__ -#endif - -#ifndef __aicore__ -#define __aicore__ [aicore] -#endif - -template -AICORE constexpr inline T CeilAlign(T num_1, T num_2) { - if (num_2 == 0) { - return 0; - } - return (num_1 + num_2 - 1) / num_2 * num_2; -} - -template -static __aicore__ void gemm_tile_impl(__gm__ float *input_a, __gm__ float *input_b, __gm__ float *output) { - constexpr int blockAlign = C0_SIZE_BYTE / sizeof(float); - constexpr int M = CeilAlign(TILE, 16); - constexpr int K = CeilAlign(TILE, blockAlign); - constexpr int N = CeilAlign(TILE, blockAlign); - - using GlobalDataA = - GlobalTensor, Stride<1 * TILE * TILE, 1 * TILE * TILE, TILE * TILE, TILE, 1>>; - using GlobalDataB = - GlobalTensor, Stride<1 * TILE * TILE, 1 * TILE * TILE, TILE * TILE, TILE, 1>>; - using GlobalDataC = - GlobalTensor, Stride<1 * TILE * TILE, 1 * TILE * TILE, TILE * TILE, TILE, 1>>; - - GlobalDataA src0Global(input_a); - GlobalDataB src1Global(input_b); - GlobalDataC dstGlobal(output); - - using TileMatA = Tile; - using TileMatB = Tile; - - using LeftTile = TileLeft; - using RightTile = TileRight; - using AccTile = TileAcc; - - TileMatA aMatTile; - TileMatB bMatTile; - TASSIGN(aMatTile, 0x0); - TASSIGN(bMatTile, 0x20000); - - LeftTile aTile; - RightTile bTile; - AccTile cTile; - TASSIGN(aTile, 0x0); - TASSIGN(bTile, 0x0); - TASSIGN(cTile, 0x0); - - TLOAD(aMatTile, src0Global); - TLOAD(bMatTile, src1Global); - - set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); - wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); - - TMOV(aTile, aMatTile); - TMOV(bTile, bMatTile); - - set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); - wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); - - TMATMUL(cTile, aTile, 
bTile); - - set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); - wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); - - TSTORE(dstGlobal, cTile); - - pipe_sync(); -} - -extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) { - __gm__ Tensor *input_a = reinterpret_cast<__gm__ Tensor *>(args[0]); - __gm__ Tensor *input_b = reinterpret_cast<__gm__ Tensor *>(args[1]); - __gm__ Tensor *output = reinterpret_cast<__gm__ Tensor *>(args[2]); - __gm__ Tensor *config = reinterpret_cast<__gm__ Tensor *>(args[3]); - - __gm__ int64_t *cfg = reinterpret_cast<__gm__ int64_t *>(config->buffer.addr); - uint64_t tile_size = static_cast(cfg[0]); - uint64_t tile_elems = tile_size * tile_size; - int num_tiles = static_cast(cfg[3]); - - __gm__ float *base_a = reinterpret_cast<__gm__ float *>(input_a->buffer.addr) + input_a->start_offset; - __gm__ float *base_b = reinterpret_cast<__gm__ float *>(input_b->buffer.addr) + input_b->start_offset; - __gm__ float *base_c = reinterpret_cast<__gm__ float *>(output->buffer.addr) + output->start_offset; - - for (int tile_idx = 0; tile_idx < num_tiles; tile_idx++) { - __gm__ float *a_ptr = base_a + (tile_idx * tile_elems); - __gm__ float *b_ptr = base_b + (tile_idx * tile_elems); - __gm__ float *c_ptr = base_c + (tile_idx * tile_elems); - - switch (tile_size) { - case 16: - gemm_tile_impl<16>(a_ptr, b_ptr, c_ptr); - break; - case 32: - gemm_tile_impl<32>(a_ptr, b_ptr, c_ptr); - break; - case 64: - gemm_tile_impl<64>(a_ptr, b_ptr, c_ptr); - break; - case 128: - gemm_tile_impl<128>(a_ptr, b_ptr, c_ptr); - break; - default: - break; - } - } -} diff --git a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_simple_matmul.cpp b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_simple_matmul.cpp new file mode 100644 index 000000000..c0cbeddd2 --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_simple_matmul.cpp @@ -0,0 +1,184 @@ +/* + * Copyright (c) 
PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ----------------------------------------------------------------------------------------------------------- + */ +/** + * Tile-based Matrix Multiplication Kernel (Cube Core) + * + * Computes: output = input_a @ input_b (tile_size x tile_size tile matmul) + * Uses TMATMUL instruction + * + * Tile size is determined by golden.py configuration and passed through + * tensor shapes from orchestration. 
+ * + * Args (Tensor*): + * args[0] = input_a (INPUT) + * args[1] = input_b (INPUT) + * args[2] = output (OUTPUT) + * args[3] = config (INPUT) - int64_t[4]: [tile_size, grid_k, num_groups, incore_loop] + */ + +#include +#include +#include +#include + +#include "tensor.h" + +using namespace pto; + +#include "pipe_sync.h" + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +template +AICORE constexpr inline T CeilAlign(T num_1, T num_2) { + if (num_2 == 0) { + return 0; + } + return (num_1 + num_2 - 1) / num_2 * num_2; +} + +using namespace pto; + +constexpr unsigned NUM_BLOCKS = 20; // number of AICs +constexpr unsigned UB_SIZE = 0x30000; // 192KB UB of A2A3 + +template +AICORE inline void SetFlag(uint32_t id) { + set_flag(SrcPipe, DstPipe, static_cast(id)); +} +template +AICORE inline void WaitFlag(uint32_t id) { + wait_flag(SrcPipe, DstPipe, static_cast(id)); +} + +template +AICORE void runKernelSimpleMatMul(__gm__ InputT *a, __gm__ InputT *b, __gm__ OutputT *c) { + constexpr uint32_t tile_len = matrix_size * matrix_size; + + /* Global Memory / Tensors */ + using TensorShapeIn = TileShape2D; + using TensorStridesIn = BaseShape2D; + using GlobalTensorIn = GlobalTensor; + + using TensorShapeOut = TileShape2D; + using TensorStridesOut = BaseShape2D; + using GlobalTensorOut = GlobalTensor; + + /* L1 Memory */ + using TileL1AB = Tile< + TileType::Mat, InputT, matrix_size, matrix_size, BLayout::ColMajor, matrix_size, matrix_size, SLayout::RowMajor, + 512>; + + /* L0 Memory */ + using TileL0A = TileLeft; + using TileL0B = TileRight; + using TileL0C = TileAcc; + + GlobalTensorIn a_global_in(a); + GlobalTensorIn b_global_in(b); + GlobalTensorOut c_global_out(c); + TASSIGN(a_global_in, a); + TASSIGN(b_global_in, b); + TASSIGN(c_global_out, c); + + TileL1AB a_l1_tile; + TileL1AB b_l1_tile; + TASSIGN(a_l1_tile, 0x0); + TASSIGN(b_l1_tile, 0x0 + tile_len * sizeof(InputT)); + + TileL0A a_l0_tile; + TileL0B b_l0_tile; + TileL0C 
c_l0_tile; + // L0A/L0B/L0C are distinct scratchpads + TASSIGN(a_l0_tile, 0x0); + TASSIGN(b_l0_tile, 0x0); + TASSIGN(c_l0_tile, 0x0); + + // LOAD matrix A from GM -> L1 (MTE2) + TLOAD(a_l1_tile, a_global_in); + TLOAD(b_l1_tile, b_global_in); + SetFlag(0); + WaitFlag(0); + + // Copy A from L1 -> L0 (MTE1) + // MatMul unit waits (using id:0) for MTE1 to load matrices into L0A/B + TMOV(a_l0_tile, a_l1_tile); + // Copy B from L1 -> L0B + // MatMul unit waits (using id:1) for MTE1 to load matrices into L0A/B + TMOV(b_l0_tile, b_l1_tile); + SetFlag(0); // MTE1 pipe sets flag for MM pipe + WaitFlag(0); // MM pipe waits for MTE1 pipe to set flag + + // MATMUL (M) + TMATMUL(c_l0_tile, a_l0_tile, b_l0_tile); + pipe_barrier(PIPE_ALL); + SetFlag(0); // M pipe sets flag for FIX pipe + WaitFlag(0); // FIX pipe waits for M pipe to set flag + TSTORE(c_global_out, c_l0_tile); +} + +template +AICORE void run_simple_matmul(__gm__ T *a, __gm__ T *b, __gm__ float *c, uint32_t matrix_size) { + static_assert( + std::is_same_v or std::is_same_v or std::is_same_v, + "simple_matmul supports only fp16/bf16/fp32." 
+ ); + + switch (matrix_size) { + case 16: + runKernelSimpleMatMul(a, b, c); + break; + case 32: + runKernelSimpleMatMul(a, b, c); + break; + + case 64: + runKernelSimpleMatMul(a, b, c); + break; + + case 96: + runKernelSimpleMatMul(a, b, c); + break; + + case 128: + runKernelSimpleMatMul(a, b, c); + break; + } +} + +/** + * Element-wise multiplication kernel implementation + * + * Unified signature: all arguments passed via int64_t array + * @param args Argument array: + * args[0] = src0 pointer (first input tensor) + * args[1] = src1 pointer (second input tensor) + * args[2] = out pointer (output tensor) + * args[3] = size (number of elements) + */ +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t *args) { + // Unpack arguments (Tensor* pointers from runtime) + __gm__ Tensor *a = reinterpret_cast<__gm__ Tensor *>(args[0]); + __gm__ Tensor *b = reinterpret_cast<__gm__ Tensor *>(args[1]); + __gm__ Tensor *c = reinterpret_cast<__gm__ Tensor *>(args[2]); + __gm__ Tensor *config = reinterpret_cast<__gm__ Tensor *>(args[3]); + + __gm__ int64_t *cfg = reinterpret_cast<__gm__ int64_t *>(config->buffer.addr); + const uint64_t tile_size = static_cast(cfg[0]); + uint64_t tile_elems = tile_size * tile_size; + const int num_tiles = static_cast(cfg[3]); +} \ No newline at end of file diff --git a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aiv/kernel_tile_add.cpp b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aiv/kernel_tile_add.cpp deleted file mode 100644 index c80e88244..000000000 --- a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aiv/kernel_tile_add.cpp +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Copyright (c) PyPTO Contributors. - * This program is free software, you can redistribute it and/or modify it under the terms and conditions of - * CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. 
You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ----------------------------------------------------------------------------------------------------------- - */ -/** - * Tile-based Element-wise Addition Kernel (Vector Core) - INOUT Pattern - * - * Computes: C_tile = C_tile + P (tile_size x tile_size tile accumulation) - * Uses TADD instruction - * - * Tile size is determined by golden.py configuration and passed through - * tensor shapes from orchestration. - * - * Args (Tensor*): - * args[0] = C_tile (INOUT: read + write accumulator) - * args[1] = P (INPUT: matmul result to accumulate) - * args[2] = config (INPUT) - int64_t[4]: [tile_size, grid_k, num_groups, incore_loop] - */ - -#include -#include -#include - -#include "tensor.h" - -using namespace pto; - -#include "pipe_sync.h" - -#ifndef __gm__ -#define __gm__ -#endif - -#ifndef __aicore__ -#define __aicore__ [aicore] -#endif - -template -static __aicore__ void tile_add_impl(__gm__ float *c_ptr, __gm__ float *p_ptr) { - using DynShapeDim5 = Shape<1, 1, 1, TILE, TILE>; - using DynStridDim5 = Stride<1, 1, 1, TILE, 1>; - using GlobalData = GlobalTensor; - using TileData = Tile; - - TileData cTile(TILE, TILE); - TileData pTile(TILE, TILE); - TileData outTile(TILE, TILE); - TASSIGN(cTile, 0x0); - TASSIGN(pTile, 0x10000); - TASSIGN(outTile, 0x20000); - - GlobalData cGlobal(c_ptr); - GlobalData pGlobal(p_ptr); - GlobalData outGlobal(c_ptr); // write back to same C location - - TLOAD(cTile, cGlobal); - TLOAD(pTile, pGlobal); - set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - TADD(outTile, cTile, pTile); - set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); - wait_flag(PIPE_V, 
PIPE_MTE3, EVENT_ID0); - TSTORE(outGlobal, outTile); - pipe_sync(); -} - -extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) { - __gm__ Tensor *c_tensor = reinterpret_cast<__gm__ Tensor *>(args[0]); - __gm__ Tensor *p_tensor = reinterpret_cast<__gm__ Tensor *>(args[1]); - __gm__ Tensor *config = reinterpret_cast<__gm__ Tensor *>(args[2]); - - __gm__ int64_t *cfg = reinterpret_cast<__gm__ int64_t *>(config->buffer.addr); - uint64_t tile_size = static_cast(cfg[0]); - uint64_t tile_elems = tile_size * tile_size; - int num_tiles = static_cast(cfg[3]); - - __gm__ float *base_c = reinterpret_cast<__gm__ float *>(c_tensor->buffer.addr) + c_tensor->start_offset; - __gm__ float *base_p = reinterpret_cast<__gm__ float *>(p_tensor->buffer.addr) + p_tensor->start_offset; - - for (int tile_idx = 0; tile_idx < num_tiles; tile_idx++) { - __gm__ float *c_ptr = base_c + (tile_idx * tile_elems); - __gm__ float *p_ptr = base_p + (tile_idx * tile_elems); - - switch (tile_size) { - case 16: - tile_add_impl<16>(c_ptr, p_ptr); - break; - case 32: - tile_add_impl<32>(c_ptr, p_ptr); - break; - case 64: - tile_add_impl<64>(c_ptr, p_ptr); - break; - case 128: - tile_add_impl<128>(c_ptr, p_ptr); - break; - default: - break; - } - } -} diff --git a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/orchestration/bgemm_orch.cpp b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/orchestration/triangular_inverse_orch.cpp similarity index 62% rename from examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/orchestration/bgemm_orch.cpp rename to examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/orchestration/triangular_inverse_orch.cpp index 116487942..d5001012e 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/orchestration/bgemm_orch.cpp +++ b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/orchestration/triangular_inverse_orch.cpp 
@@ -73,49 +73,12 @@ __attribute__((visibility("default"))) void aicpu_orchestration_entry(const Chip uint32_t group_shapes[1] = {static_cast(group_tile_elems)}; TensorCreateInfo group_ci(group_shapes, 1, DataType::FLOAT32); - int total_gemm = 0; - int total_add = 0; - - // A/B layout: [num_groups, grid_k, incore_loop, tile_size, tile_size] - // C layout: [incore_loop * num_groups, tile_size, tile_size] - for (int group_idx = 0; group_idx < num_groups; group_idx++) { - PTO2_SCOPE_GUARD(); - - uint32_t c_elem_offset = static_cast(static_cast(group_idx) * group_tile_elems); - uint32_t c_view_offsets[1] = {c_elem_offset}; - Tensor C_view = ext_C.view(group_shapes, c_view_offsets); - - for (int k_idx = 0; k_idx < grid_k; k_idx++) { - // In layout [num_groups, grid_k, incore_loop, tile_size, tile_size], - // offset = (group_idx * grid_k + k_idx) * incore_loop * tile_elems - uint64_t ab_offset = - (static_cast(group_idx) * grid_k + static_cast(k_idx)) * group_tile_elems; - - uint32_t a_view_offsets[1] = {static_cast(ab_offset)}; - Tensor A_view = ext_A.view(group_shapes, a_view_offsets); - uint32_t b_view_offsets[1] = {static_cast(ab_offset)}; - Tensor B_view = ext_B.view(group_shapes, b_view_offsets); - Arg params_gemm; - params_gemm.add_input(A_view); - params_gemm.add_input(B_view); - params_gemm.add_output(group_ci); - params_gemm.add_input(ext_config); - TaskOutputTensors gemm_outs = rt_submit_aic_task(FUNC_GEMM_TILE, params_gemm); - total_gemm++; - - Arg params_add; - params_add.add_inout(C_view); - params_add.add_input(gemm_outs.get_ref(0)); - params_add.add_input(ext_config); - rt_submit_aiv_task(FUNC_TILE_ADD, params_add); - total_add++; - } - } - - LOG_INFO_V0( - "[bgemm_orch] Submitted %d gemm tasks and %d add tasks (%d total)", total_gemm, total_add, - total_gemm + total_add - ); + Arg params_gemm; + params_gemm.add_input(ext_A); + params_gemm.add_input(ext_B); + params_gemm.add_output(ext_C); + params_gemm.add_input(ext_config); + TaskOutputTensors gemm_outs 
= rt_submit_aic_task(FUNC_GEMM_TILE, params_gemm); } } // extern "C" diff --git a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py index ccd3323b1..33cf88ae8 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py +++ b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py @@ -22,7 +22,7 @@ class TestTriangularInverse(SceneTestCase): CALLABLE = { "orchestration": { - "source": "kernels/orchestration/bgemm_orch.cpp", + "source": "kernels/orchestration/triangular_inverse_orch.cpp", "function_name": "aicpu_orchestration_entry", "signature": [D.IN, D.IN, D.OUT, D.IN], }, @@ -30,17 +30,10 @@ class TestTriangularInverse(SceneTestCase): { "func_id": 0, "name": "GEMM", - "source": "kernels/aic/kernel_gemm_tile.cpp", + "source": "kernels/aic/kernel_simple_matmul.cpp", "core_type": "aic", "signature": [D.IN, D.IN, D.OUT], - }, - { - "func_id": 1, - "name": "ADD", - "source": "kernels/aiv/kernel_tile_add.cpp", - "core_type": "aiv", - "signature": [D.INOUT, D.IN], - }, + } ], } @@ -49,21 +42,21 @@ class TestTriangularInverse(SceneTestCase): "name": "Case0", "platforms": ["a2a3sim", "a2a3"], "config": {"aicpu_thread_num": 2, "block_dim": 20}, - "params": {"matmul_add_task_num": 500, "incore_data_size": 128, "incore_loop": 4, "grid_k": 2}, + "params": {"matmul_add_task_num": 1, "incore_data_size": 128, "incore_loop": 1, "grid_k": 1}, }, { "name": "Case1", "manual": True, "platforms": ["a2a3sim", "a2a3"], "config": {"aicpu_thread_num": 2, "block_dim": 20}, - "params": {"matmul_add_task_num": 64, "incore_data_size": 128, "incore_loop": 4, "grid_k": 2}, + "params": {"matmul_add_task_num": 1, "incore_data_size": 128, "incore_loop": 1, "grid_k": 1}, }, { "name": "Case2", "manual": True, "platforms": ["a2a3sim", "a2a3"], "config": {"aicpu_thread_num": 2, 
"block_dim": 20}, - "params": {"matmul_add_task_num": 256, "incore_data_size": 128, "incore_loop": 4, "grid_k": 2}, + "params": {"matmul_add_task_num": 1, "incore_data_size": 128, "incore_loop": 1, "grid_k": 1}, }, ] @@ -72,9 +65,9 @@ def generate_args(self, params): incore_loop = params["incore_loop"] grid_k = params["grid_k"] num_groups = params["matmul_add_task_num"] // grid_k - A = torch.randn(num_groups, grid_k, incore_loop, tile_size, tile_size, dtype=torch.float32) * 0.01 - B = torch.randn(num_groups, grid_k, incore_loop, tile_size, tile_size, dtype=torch.float32) * 0.01 - C = torch.zeros(incore_loop * num_groups, tile_size, tile_size, dtype=torch.float32) + A = torch.randn(tile_size, tile_size, dtype=torch.float32) * 0.1 + B = torch.randn(tile_size, tile_size, dtype=torch.float32) * 0.1 + C = torch.zeros(tile_size, tile_size, dtype=torch.float32) config = torch.tensor([tile_size, grid_k, num_groups, incore_loop], dtype=torch.int64) return TaskArgsBuilder( Tensor("A", A.flatten()), Tensor("B", B.flatten()), Tensor("C", C.flatten()), Tensor("config", config) @@ -85,14 +78,11 @@ def compute_golden(self, args, params): incore_loop = params["incore_loop"] grid_k = params["grid_k"] num_groups = params["matmul_add_task_num"] // grid_k - A = args.A.reshape(num_groups, grid_k, incore_loop, tile_size, tile_size) - B = args.B.reshape(num_groups, grid_k, incore_loop, tile_size, tile_size) - C = args.C.reshape(incore_loop * num_groups, tile_size, tile_size) + A = args.A.reshape(tile_size, tile_size) + B = args.B.reshape(tile_size, tile_size) + C = args.C.reshape(tile_size, tile_size) C[:] = 0.0 - for group in range(num_groups): - for k_idx in range(grid_k): - for i in range(incore_loop): - C[group * incore_loop + i] += torch.matmul(A[group, k_idx, i], B[group, k_idx, i]) + C = A @ B if __name__ == "__main__": From c1914ebf2b6e89c1537700bafd8814113ebfb151 Mon Sep 17 00:00:00 2001 From: anastasios Date: Thu, 30 Apr 2026 12:56:35 +0000 Subject: [PATCH 03/16] fix --- 
.../benchmark_bgemm/test_benchmark_bgemm.py | 8 ++--- .../test_triangular_inverse.py | 29 ++++++++----------- 2 files changed, 16 insertions(+), 21 deletions(-) diff --git a/examples/a2a3/tensormap_and_ringbuffer/benchmark_bgemm/test_benchmark_bgemm.py b/examples/a2a3/tensormap_and_ringbuffer/benchmark_bgemm/test_benchmark_bgemm.py index a3b888f75..d8b073922 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/benchmark_bgemm/test_benchmark_bgemm.py +++ b/examples/a2a3/tensormap_and_ringbuffer/benchmark_bgemm/test_benchmark_bgemm.py @@ -17,8 +17,8 @@ @scene_test(level=2, runtime="tensormap_and_ringbuffer") class TestBenchmarkBgemm(SceneTestCase): - RTOL = 1e-3 - ATOL = 1e-3 + RTOL = 1e-5 + ATOL = 1e-5 CALLABLE = { "orchestration": { @@ -32,14 +32,14 @@ class TestBenchmarkBgemm(SceneTestCase): "name": "GEMM", "source": "kernels/aic/kernel_gemm_tile.cpp", "core_type": "aic", - "signature": [D.IN, D.IN, D.OUT], + "signature": [D.IN, D.IN, D.OUT, D.IN], }, { "func_id": 1, "name": "ADD", "source": "kernels/aiv/kernel_tile_add.cpp", "core_type": "aiv", - "signature": [D.INOUT, D.IN], + "signature": [D.INOUT, D.IN, D.IN], }, ], } diff --git a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py index 33cf88ae8..72afbe544 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py +++ b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py @@ -42,47 +42,42 @@ class TestTriangularInverse(SceneTestCase): "name": "Case0", "platforms": ["a2a3sim", "a2a3"], "config": {"aicpu_thread_num": 2, "block_dim": 20}, - "params": {"matmul_add_task_num": 1, "incore_data_size": 128, "incore_loop": 1, "grid_k": 1}, + "params": {"batch_dim": 4, "incore_data_size": 128}, }, { "name": "Case1", "manual": True, "platforms": ["a2a3sim", "a2a3"], "config": {"aicpu_thread_num": 2, 
"block_dim": 20}, - "params": {"matmul_add_task_num": 1, "incore_data_size": 128, "incore_loop": 1, "grid_k": 1}, + "params": {"batch_dim": 4, "incore_data_size": 128}, }, { "name": "Case2", "manual": True, "platforms": ["a2a3sim", "a2a3"], "config": {"aicpu_thread_num": 2, "block_dim": 20}, - "params": {"matmul_add_task_num": 1, "incore_data_size": 128, "incore_loop": 1, "grid_k": 1}, + "params": {"batch_dim": 4, "incore_data_size": 128}, }, ] def generate_args(self, params): tile_size = params["incore_data_size"] - incore_loop = params["incore_loop"] - grid_k = params["grid_k"] - num_groups = params["matmul_add_task_num"] // grid_k - A = torch.randn(tile_size, tile_size, dtype=torch.float32) * 0.1 - B = torch.randn(tile_size, tile_size, dtype=torch.float32) * 0.1 - C = torch.zeros(tile_size, tile_size, dtype=torch.float32) - config = torch.tensor([tile_size, grid_k, num_groups, incore_loop], dtype=torch.int64) + batch_dim = params["batch_dim"] + A = torch.randn(batch_dim, tile_size, tile_size, dtype=torch.float32) * 0.1 + B = torch.randn(batch_dim, tile_size, tile_size, dtype=torch.float32) * 0.1 + C = torch.zeros(batch_dim, tile_size, tile_size, dtype=torch.float32) + config = torch.tensor([tile_size, batch_dim], dtype=torch.int64) return TaskArgsBuilder( Tensor("A", A.flatten()), Tensor("B", B.flatten()), Tensor("C", C.flatten()), Tensor("config", config) ) def compute_golden(self, args, params): tile_size = params["incore_data_size"] - incore_loop = params["incore_loop"] - grid_k = params["grid_k"] - num_groups = params["matmul_add_task_num"] // grid_k - A = args.A.reshape(tile_size, tile_size) - B = args.B.reshape(tile_size, tile_size) - C = args.C.reshape(tile_size, tile_size) - C[:] = 0.0 + batch_dim = params["batch_dim"] + A = args.A.reshape(batch_dim, tile_size, tile_size) + B = args.B.reshape(batch_dim, tile_size, tile_size) C = A @ B + return C if __name__ == "__main__": From 40bf56cf55f8af6dffa3ca19a344837f5bb514f2 Mon Sep 17 00:00:00 2001 From: 
anastasios Date: Mon, 18 May 2026 14:38:55 +0000 Subject: [PATCH 04/16] wip --- Makefile | 4 ++++ .../triangular_inverse_example/test_triangular_inverse.py | 6 +++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index 0ed2ee243..2f854df67 100644 --- a/Makefile +++ b/Makefile @@ -2,3 +2,7 @@ all: python examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py -p a2a3sim + + +run_on_npu: + python examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py -p a2a3 diff --git a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py index 72afbe544..eff26efcd 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py +++ b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py @@ -41,21 +41,21 @@ class TestTriangularInverse(SceneTestCase): { "name": "Case0", "platforms": ["a2a3sim", "a2a3"], - "config": {"aicpu_thread_num": 2, "block_dim": 20}, + "config": {"aicpu_thread_num": 4, "block_dim": 16}, "params": {"batch_dim": 4, "incore_data_size": 128}, }, { "name": "Case1", "manual": True, "platforms": ["a2a3sim", "a2a3"], - "config": {"aicpu_thread_num": 2, "block_dim": 20}, + "config": {"aicpu_thread_num": 4, "block_dim": 16}, "params": {"batch_dim": 4, "incore_data_size": 128}, }, { "name": "Case2", "manual": True, "platforms": ["a2a3sim", "a2a3"], - "config": {"aicpu_thread_num": 2, "block_dim": 20}, + "config": {"aicpu_thread_num": 4, "block_dim": 16}, "params": {"batch_dim": 4, "incore_data_size": 128}, }, ] From 783e20960dc785f4a6c4ffa82a86b0b4b9589a1f Mon Sep 17 00:00:00 2001 From: Anastasios Zouzias Date: Tue, 19 May 2026 21:06:34 +0200 Subject: [PATCH 05/16] (tri_inv) wip --- .../kernels/aic/kernel_simple_matmul.cpp | 184 ---- 
.../kernels/aic/kernel_tri_inv_rec_unroll.cpp | 905 ++++++++++++++++++ .../orchestration/triangular_inverse_orch.cpp | 74 +- .../test_triangular_inverse.py | 72 +- 4 files changed, 981 insertions(+), 254 deletions(-) delete mode 100644 examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_simple_matmul.cpp create mode 100644 examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_tri_inv_rec_unroll.cpp diff --git a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_simple_matmul.cpp b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_simple_matmul.cpp deleted file mode 100644 index c0cbeddd2..000000000 --- a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_simple_matmul.cpp +++ /dev/null @@ -1,184 +0,0 @@ -/* - * Copyright (c) PyPTO Contributors. - * This program is free software, you can redistribute it and/or modify it under the terms and conditions of - * CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ----------------------------------------------------------------------------------------------------------- - */ -/** - * Tile-based Matrix Multiplication Kernel (Cube Core) - * - * Computes: output = input_a @ input_b (tile_size x tile_size tile matmul) - * Uses TMATMUL instruction - * - * Tile size is determined by golden.py configuration and passed through - * tensor shapes from orchestration. 
- * - * Args (Tensor*): - * args[0] = input_a (INPUT) - * args[1] = input_b (INPUT) - * args[2] = output (OUTPUT) - * args[3] = config (INPUT) - int64_t[4]: [tile_size, grid_k, num_groups, incore_loop] - */ - -#include -#include -#include -#include - -#include "tensor.h" - -using namespace pto; - -#include "pipe_sync.h" - -#ifndef __gm__ -#define __gm__ -#endif - -#ifndef __aicore__ -#define __aicore__ [aicore] -#endif - -template -AICORE constexpr inline T CeilAlign(T num_1, T num_2) { - if (num_2 == 0) { - return 0; - } - return (num_1 + num_2 - 1) / num_2 * num_2; -} - -using namespace pto; - -constexpr unsigned NUM_BLOCKS = 20; // number of AICs -constexpr unsigned UB_SIZE = 0x30000; // 192KB UB of A2A3 - -template -AICORE inline void SetFlag(uint32_t id) { - set_flag(SrcPipe, DstPipe, static_cast(id)); -} -template -AICORE inline void WaitFlag(uint32_t id) { - wait_flag(SrcPipe, DstPipe, static_cast(id)); -} - -template -AICORE void runKernelSimpleMatMul(__gm__ InputT *a, __gm__ InputT *b, __gm__ OutputT *c) { - constexpr uint32_t tile_len = matrix_size * matrix_size; - - /* Global Memory / Tensors */ - using TensorShapeIn = TileShape2D; - using TensorStridesIn = BaseShape2D; - using GlobalTensorIn = GlobalTensor; - - using TensorShapeOut = TileShape2D; - using TensorStridesOut = BaseShape2D; - using GlobalTensorOut = GlobalTensor; - - /* L1 Memory */ - using TileL1AB = Tile< - TileType::Mat, InputT, matrix_size, matrix_size, BLayout::ColMajor, matrix_size, matrix_size, SLayout::RowMajor, - 512>; - - /* L0 Memory */ - using TileL0A = TileLeft; - using TileL0B = TileRight; - using TileL0C = TileAcc; - - GlobalTensorIn a_global_in(a); - GlobalTensorIn b_global_in(b); - GlobalTensorOut c_global_out(c); - TASSIGN(a_global_in, a); - TASSIGN(b_global_in, b); - TASSIGN(c_global_out, c); - - TileL1AB a_l1_tile; - TileL1AB b_l1_tile; - TASSIGN(a_l1_tile, 0x0); - TASSIGN(b_l1_tile, 0x0 + tile_len * sizeof(InputT)); - - TileL0A a_l0_tile; - TileL0B b_l0_tile; - TileL0C 
c_l0_tile; - // L0A/L0B/L0C are distinct scratchpads - TASSIGN(a_l0_tile, 0x0); - TASSIGN(b_l0_tile, 0x0); - TASSIGN(c_l0_tile, 0x0); - - // LOAD matrix A from GM -> L1 (MTE2) - TLOAD(a_l1_tile, a_global_in); - TLOAD(b_l1_tile, b_global_in); - SetFlag(0); - WaitFlag(0); - - // Copy A from L1 -> L0 (MTE1) - // MatMul unit waits (using id:0) for MTE1 to load matrices into L0A/B - TMOV(a_l0_tile, a_l1_tile); - // Copy B from L1 -> L0B - // MatMul unit waits (using id:1) for MTE1 to load matrices into L0A/B - TMOV(b_l0_tile, b_l1_tile); - SetFlag(0); // MTE1 pipe sets flag for MM pipe - WaitFlag(0); // MM pipe waits for MTE1 pipe to set flag - - // MATMUL (M) - TMATMUL(c_l0_tile, a_l0_tile, b_l0_tile); - pipe_barrier(PIPE_ALL); - SetFlag(0); // M pipe sets flag for FIX pipe - WaitFlag(0); // FIX pipe waits for M pipe to set flag - TSTORE(c_global_out, c_l0_tile); -} - -template -AICORE void run_simple_matmul(__gm__ T *a, __gm__ T *b, __gm__ float *c, uint32_t matrix_size) { - static_assert( - std::is_same_v or std::is_same_v or std::is_same_v, - "simple_matmul supports only fp16/bf16/fp32." 
- ); - - switch (matrix_size) { - case 16: - runKernelSimpleMatMul(a, b, c); - break; - case 32: - runKernelSimpleMatMul(a, b, c); - break; - - case 64: - runKernelSimpleMatMul(a, b, c); - break; - - case 96: - runKernelSimpleMatMul(a, b, c); - break; - - case 128: - runKernelSimpleMatMul(a, b, c); - break; - } -} - -/** - * Element-wise multiplication kernel implementation - * - * Unified signature: all arguments passed via int64_t array - * @param args Argument array: - * args[0] = src0 pointer (first input tensor) - * args[1] = src1 pointer (second input tensor) - * args[2] = out pointer (output tensor) - * args[3] = size (number of elements) - */ -extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t *args) { - // Unpack arguments (Tensor* pointers from runtime) - __gm__ Tensor *a = reinterpret_cast<__gm__ Tensor *>(args[0]); - __gm__ Tensor *b = reinterpret_cast<__gm__ Tensor *>(args[1]); - __gm__ Tensor *c = reinterpret_cast<__gm__ Tensor *>(args[2]); - __gm__ Tensor *config = reinterpret_cast<__gm__ Tensor *>(args[3]); - - __gm__ int64_t *cfg = reinterpret_cast<__gm__ int64_t *>(config->buffer.addr); - const uint64_t tile_size = static_cast(cfg[0]); - uint64_t tile_elems = tile_size * tile_size; - const int num_tiles = static_cast(cfg[3]); -} \ No newline at end of file diff --git a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_tri_inv_rec_unroll.cpp b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_tri_inv_rec_unroll.cpp new file mode 100644 index 000000000..cad69147c --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_tri_inv_rec_unroll.cpp @@ -0,0 +1,905 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +All rights reserved. + +See LICENSE in the root of the software repository: +https://github.com/huawei-csl/pto-kernels/ +for the full License text. 
+*/ +/* + * Triangular matrix inverse kernel (recursive unrolled algorithm). + * + * Adapted from pto-kernels for the simpler framework: + * - kernel_utils.h content is inlined below + * - constants.h (static -I matrices) is replaced by a tensor arg passed + * from the Python test via the orchestration + * - The pto-kernels __global__ entry points are replaced by a single + * kernel_entry(__gm__ int64_t *args) as required by simpler + * + * Args layout (positional, via Tensor* packed in int64_t array): + * args[0] = M (INPUT) fp16 triangular matrices [num_matrices, N, N] + * args[1] = I_neg (INPUT) fp16 negative identity [N, N] + * args[2] = M_inv (OUTPUT) fp16 result [num_matrices, N, N] + * args[3] = config (INPUT) int64[3]: [matrix_size, num_matrices, is_lower] + */ + +#include +#include + +#include "tensor.h" + +using namespace pto; + +#ifndef GM_ADDR +#define GM_ADDR __gm__ uint8_t* +#endif + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +// --------------------------------------------------------------------------- +// Inlined content of kernel_utils.h (from pto-kernels) +// --------------------------------------------------------------------------- +namespace kernel_utils { + +template +AICORE inline void SetWaitFlag(uint32_t id) { + set_flag(SrcPipe, DstPipe, static_cast(id)); + wait_flag(SrcPipe, DstPipe, static_cast(id)); +} + +template ::value && + std::is_integral::value, + int>::type = 0> +AICORE inline T1 CeilDiv(T1 value, T2 divisor) { + return (value + divisor - 1) / divisor; +} + +#define BSND_OFFSET(tile_id, N, S, D) \ + (((tile_id) / (N)) * (S) * (N) * (D) + ((tile_id) % (N)) * (D)) + +AICORE inline uint32_t GetBSNDFixedTileOffset(uint32_t tile_id, + uint32_t num_bsnd_heads, + uint32_t matrix_size) { + return BSND_OFFSET(tile_id, num_bsnd_heads, matrix_size, matrix_size); +} + +struct BSNDVarlenTileInfo { + uint32_t bsnd_offset; + uint32_t valid_size; +}; + +AICORE inline BSNDVarlenTileInfo 
GetBSNDVarlenTileInfoFromCuSeqlens( + uint32_t tile_id, uint32_t num_bsnd_heads, uint32_t matrix_size, + __gm__ int32_t* cu_seqlens) { + const uint32_t head_idx = tile_id % num_bsnd_heads; + const uint32_t chunk_idx = tile_id / num_bsnd_heads; + + uint32_t seq_start = static_cast(cu_seqlens[0]); + uint32_t accumulated_chunks = 0; + for (uint32_t seq_idx = 0;; ++seq_idx) { + const uint32_t seq_end = static_cast(cu_seqlens[seq_idx + 1]); + const uint32_t seq_len = seq_end - seq_start; + const uint32_t seq_num_chunks = CeilDiv(seq_len, matrix_size); + if (chunk_idx < accumulated_chunks + seq_num_chunks) { + const uint32_t local_chunk_idx = chunk_idx - accumulated_chunks; + const uint32_t row_start = seq_start + local_chunk_idx * matrix_size; + const uint32_t valid_size = + min(static_cast(seq_end - row_start), matrix_size); + return {row_start * num_bsnd_heads * matrix_size + head_idx * matrix_size, + valid_size}; + } + accumulated_chunks += seq_num_chunks; + seq_start = seq_end; + } +} + +} // namespace kernel_utils + +using namespace kernel_utils; + +// --------------------------------------------------------------------------- +// Kernel template code (verbatim from pto-kernels kernel_tri_inv_rec_unroll.cpp) +// --------------------------------------------------------------------------- + +/** + * @brief: Takes as input two matrices of size MatrixSize * MatrixSize each. + * The src matrix lies in L1, while the dst matrix lies either in L0A or L0B. + * This kernel copies only the diagonal blocks (fractals) of size FractalSize * + * FractalSize from the src matrix to the dst matrix. + * + * @tparam InputT Input data type (fp16). + * @tparam FractalSize Size of each fractal matrix (diagonal block). + * @tparam MatrixSize Size of the entire input/output matrices. + * @tparam SrcL1TileT The actual tile type of the src matrix. + * @tparam DstL0TileT The actual tile type of the dst matrix. + * + * @param src Tile in L1 memory. + * @param dst Tile in L0A or L0B memory. 
+ */ +template +AICORE inline void CopyDiagonalFractalsL1ToL0(SrcL1TileT src, DstL0TileT dst) { + constexpr uint32_t NumFractals = MatrixSize / FractalSize; + constexpr bool is_left = + std::is_same_v>; + constexpr TileType LeftOrRight = is_left ? TileType::Left : TileType::Right; + constexpr SLayout InnerLayout = + is_left ? SLayout::RowMajor : SLayout::ColMajor; + + Tile + fractals[NumFractals]; + const std::uintptr_t starting_address = + reinterpret_cast(dst.data()); + for (uint32_t i = 0; i < NumFractals; ++i) { + TASSIGN(fractals[i], starting_address + i * FractalSize * + (MatrixSize + FractalSize) * + sizeof(InputT)); + TEXTRACT(fractals[i], src, i * FractalSize, i * FractalSize); + } +} + +/** + * @brief: Takes as input two matrices of size MatrixSize * MatrixSize each, + * and an integer block_size. The src matrix lies in L1, while the dst matrix + * either in L0A or L0B. This method copies some of the diagonal blocks from the + * input to the output as follows: + * - If dst is in L0A (left): copy even diagonal blocks 0, 2, 4, ... + * - If dst is in L0B (right): copy odd blocks 1, 3, 5, ... + * Important note: the dst matrix should be initialized to all-zeros before + * calling this method + * + * @tparam InputT Input data type (fp16). + * @tparam FractalSize Size of each fractal matrix (diagonal block). + * @tparam MatrixSize Size of the entire input/output matrices. + * @tparam SrcL1TileT The actual tile type of the src matrix. + * @tparam DstL0TileT The actual tile type of the dst matrix. + * + * @param src Tile in L1 memory. + * @param dst Tile in L0A or L0B memory. + * @param block_size Size of diagonal blocks. Needs: block_size >= FractalSize. + * @param swap_parity If true, then the parity of copied blocks is swapped: left + * tile gets odd blocks, while right tile gets even blocks. This is used in the + * unrolled recursion part of the algorithm, where we need to copy alternating + * blocks of X in each iteration. 
+ */ +template +AICORE inline void CopyOddOrEvenBlocksL1ToL0(SrcL1TileT src, DstL0TileT dst, + uint32_t block_size, + bool swap_parity = false) { + constexpr bool is_left = + std::is_same_v>; + constexpr TileType LeftOrRight = is_left ? TileType::Left : TileType::Right; + constexpr SLayout InnerLayout = + is_left ? SLayout::RowMajor : SLayout::ColMajor; + + // For left: copy even blocks 0, 2, 4, ... (starting_block=0) + // For right: copy odd blocks 1, 3, 5, ... (starting_block=1) + // Default: left→even(0), right→odd(1). swap_parity flips this. + const uint32_t starting_block_index = + (is_left ? 0u : 1u) ^ (swap_parity ? 1u : 0u); + + const uint32_t num_blocks = MatrixSize / block_size; + const uint32_t num_fractals_per_block = block_size / FractalSize; + + // might need fewer fractals if block_size < FractalSize + Tile + fractals[MatrixSize / FractalSize]; + + const std::uintptr_t starting_address = + reinterpret_cast(dst.data()); + for (uint32_t i = 0; i < num_fractals_per_block; ++i) { + for (uint32_t j = 0; j < num_fractals_per_block; ++j) { + for (uint32_t b = starting_block_index; b < num_blocks; b += 2) { + const uint32_t offset = + b * (MatrixSize + FractalSize) * block_size /* block_offset */ + + i * MatrixSize * FractalSize /* col_fractal_offset */ + + j * FractalSize * FractalSize /* row_fractal_offset */; + TASSIGN(fractals[b], starting_address + offset * sizeof(InputT)); + TEXTRACT(fractals[b], src, b * block_size + i * FractalSize, + b * block_size + j * FractalSize); + } + } + } +} + +/** + * @brief: Prepares Identity and Zeros matrix. + * + * @tparam TileL1AB The type of the input tiles in L1. + * @tparam TileL0A The type of the input tiles in L0A. + * @tparam TileL0B The type of the input tiles in L0B. + * @tparam TileL0C The type of the input tiles in L0C. + * + * @param I_neg_l1_tile Tile containing the -I (negative identity) matrix. + * @param Zero_l1_tile Tile to store the all-zero matrix. 
+ * @param I_l1_tile Tile to store the identity matrix. + * @param a_l0_tile Tile in L0A for matmuls. + * @param b_l0_tile Tile in L0B for matmuls. + * @param c_l0_tile Tile in L0C for matmuls. + */ +template +AICORE inline void PrepareAuxiliaryMatrices( + TileL1AB I_neg_l1_tile, TileL1AB Zero_l1_tile, TileL1AB I_l1_tile, + TileL0A a_l0_tile, TileL0B b_l0_tile, TileL0C c_l0_tile) { + TMOV(a_l0_tile, I_neg_l1_tile); // a_l0 initialized with I_neg + TMOV(b_l0_tile, I_neg_l1_tile); // b_l0 initialized with I_neg + set_flag(PIPE_MTE1, PIPE_M, static_cast(0)); + wait_flag(PIPE_MTE1, PIPE_M, static_cast(0)); + + TMATMUL(c_l0_tile, a_l0_tile, b_l0_tile); // c_l0 contains I + set_flag(PIPE_M, PIPE_FIX, static_cast(0)); + wait_flag(PIPE_M, PIPE_FIX, static_cast(0)); + + TMOV(I_l1_tile, c_l0_tile); // I_l1 now contains I + set_flag(PIPE_FIX, PIPE_MTE1, static_cast(0)); + wait_flag(PIPE_FIX, PIPE_MTE1, static_cast(0)); + + TMOV(b_l0_tile, I_l1_tile); // b_l0 contains I + set_flag(PIPE_MTE1, PIPE_M, static_cast(0)); + wait_flag(PIPE_MTE1, PIPE_M, static_cast(0)); + + TMATMUL_ACC(c_l0_tile, c_l0_tile, a_l0_tile, + b_l0_tile); // c_l0 contains zeros + set_flag(PIPE_M, PIPE_FIX, static_cast(0)); + wait_flag(PIPE_M, PIPE_FIX, static_cast(0)); + + TMOV(Zero_l1_tile, c_l0_tile); // Zeros_l1 now contains zeros + set_flag(PIPE_FIX, PIPE_MTE1, static_cast(0)); + wait_flag(PIPE_FIX, PIPE_MTE1, static_cast(0)); +} + +/** + * @brief: Inverts a single matrix / tile of the global tensor. + * The first part of the algorithm inverts the FractalSize * FractalSize + * diagonal blocks of the input matrix (inv_trick part). The second phase + * assembles the partial inverses using the cube unig (recursive part). + * + * @tparam InputT The type of the input elements. + * @tparam TileL1AB The type of the input tiles in L1. + * @tparam TileL0A The type of the input tiles in L0A. + * @tparam TileL0B The type of the input tiles in L0B. + * @tparam TileL0C The type of the input tiles in L0C. 
+ * @tparam MatrixSize Size of the entire input/output matrices. + * @tparam FractalSize Size of matrix fractals. + * @tparam NumTilesPerCubeIter How many matrices to load and invert in a single + * cube iteration. + * + * @param X_l1_tile Tile in L1 used for intermediate computations. + * @param I_l1_tile Tile containing the identity matrix. + * @param I_neg_l1_tile Tile containing the negative identity matrix. + * @param M_neg_l1_tile Tile containing the negative input matrix. + * @param Zero_l1_tile Tile containing the all-zero matrix. + * @param Y_l1_tile Tile in L1 used for intermediate computations. + * @param a_l0_tile* Array of two tiles in L0A (for double-buffering). + * @param b_l0_tile* Array of two tiles in L0B (for double-buffering). + * @param c_l0_tile* Tile in L0C for matmuls. + * @param tile_id Index of the current tile (used for sync). + * @param swap_parity If true, then the parity of copied blocks is swapped: left + * tile gets odd blocks, while right tile gets even blocks. This is used in the + * unrolled recursion part of the algorithm, where we need to copy alternating + * blocks of X in each iteration. 
+ */ +template +AICORE inline void InvertSingleTile(TileL1AB X_l1_tile, TileL1AB I_l1_tile, + TileL1AB I_neg_l1_tile, + TileL1AB M_neg_l1_tile, + TileL1AB Zero_l1_tile, TileL1AB Y_l1_tile, + TileL0A* a_l0_tile, TileL0B* b_l0_tile, + TileL0C* c_l0_tile, const uint32_t tile_id, + const bool swap_parity = false) { + const event_t event_0 = static_cast(tile_id); + const event_t event_1 = static_cast(tile_id + NumTilesPerCubeIter); + + TMOV(b_l0_tile[0], Y_l1_tile); // b_l0[0] contains M + TMOV(a_l0_tile[0], I_neg_l1_tile); // a_l0[0] contains I_neg + set_flag(PIPE_MTE1, PIPE_M, event_0); + TMOV(a_l0_tile[1], Zero_l1_tile); + TMOV(b_l0_tile[1], Zero_l1_tile); + set_flag(PIPE_MTE1, PIPE_M, event_1); + wait_flag(PIPE_MTE1, PIPE_M, event_1); + set_flag(PIPE_M, PIPE_MTE1, event_1); + wait_flag(PIPE_M, PIPE_MTE1, event_1); + CopyDiagonalFractalsL1ToL0( + Y_l1_tile, a_l0_tile[1]); // a_l0[1] = diag_fractals(M) + CopyDiagonalFractalsL1ToL0( + Y_l1_tile, b_l0_tile[1]); // b_l0[1] = diag_fractals(M) + set_flag(PIPE_MTE1, PIPE_M, event_1); + + /* First Matmul: event_0 */ + wait_flag(PIPE_MTE1, PIPE_M, event_0); + TMATMUL(c_l0_tile[0], a_l0_tile[0], b_l0_tile[0]); // c_l0[0] contains M_neg + set_flag(PIPE_M, PIPE_FIX, event_0); + set_flag(PIPE_M, PIPE_MTE1, event_0); + + wait_flag(PIPE_M, PIPE_FIX, event_0); + TMOV(M_neg_l1_tile, c_l0_tile[0]); // M_neg_l1 now contains M_neg + set_flag(PIPE_FIX, PIPE_M, event_0); + + /* Second Matmul: event_1 */ + wait_flag(PIPE_MTE1, PIPE_M, event_1); + set_flag(PIPE_MTE1, PIPE_M, event_1); + TMATMUL(c_l0_tile[1], a_l0_tile[1], + b_l0_tile[1]); // c_l0[1] contains diag_fractals(M)^2 + set_flag(PIPE_M, PIPE_FIX, event_1); + wait_flag(PIPE_M, PIPE_FIX, event_1); + TMOV(Y_l1_tile, + c_l0_tile[1]); // Y_l1 now contains diag_fractals(M)^2 + set_flag(PIPE_FIX, PIPE_M, event_1); + wait_flag(PIPE_FIX, PIPE_M, event_1); + + /* Third Matmul: event_0*/ + wait_flag(PIPE_M, PIPE_MTE1, event_0); + TMOV(b_l0_tile[0], I_neg_l1_tile); // b_l0[0] contains I_neg + 
TMOV(a_l0_tile[0], I_neg_l1_tile); // a_l0[0] contains I_neg + set_flag(PIPE_MTE1, PIPE_M, event_0); + + wait_flag(PIPE_MTE1, PIPE_M, event_0); + wait_flag(PIPE_FIX, PIPE_M, event_0); + wait_flag(PIPE_MTE1, PIPE_M, event_1); + TMATMUL(c_l0_tile[0], a_l0_tile[1], + b_l0_tile[0]); // c_l0[0] = diag_fractals(M_neg) + set_flag(PIPE_M, PIPE_FIX, event_0); + wait_flag(PIPE_M, PIPE_FIX, event_0); + set_flag(PIPE_FIX, PIPE_M, event_0); + wait_flag(PIPE_FIX, PIPE_M, event_0); + + TMATMUL_ACC(c_l0_tile[0], c_l0_tile[0], a_l0_tile[0], + b_l0_tile[0]); // c_l0[0] has I-diag_fractals(M) + set_flag(PIPE_M, PIPE_FIX, event_1); + wait_flag(PIPE_M, PIPE_FIX, event_1); + TMOV(X_l1_tile, c_l0_tile[0]); // X_l1 now contains I-diag_fractals(M) + + /* + * Inv Trick part: + * X = I - M + * Y = M + * block_size = 1 + * while block_size < FractalSize / 2: + * Y = Y @ Y + * X = X + X @ Y + * block_size *= 2 + */ + set_flag(PIPE_FIX, PIPE_M, event_0); // store c + set_flag(PIPE_M, PIPE_MTE1, event_0); // load matrices for matmuls + set_flag(PIPE_FIX, PIPE_MTE1, event_0); + set_flag(PIPE_FIX, PIPE_M, event_1); // only for update Y + set_flag(PIPE_M, PIPE_MTE1, event_1); // only for update Y + set_flag(PIPE_FIX, PIPE_MTE1, event_1); // only for update Y + for (uint32_t block_size = 1; block_size < FractalSize / 2; block_size *= 2) { + wait_flag(PIPE_M, PIPE_MTE1, event_0); + TMOV(b_l0_tile[0], I_l1_tile); + wait_flag(PIPE_FIX, PIPE_MTE1, event_0); + TMOV(a_l0_tile[0], X_l1_tile); + set_flag(PIPE_MTE1, PIPE_M, event_0); + + wait_flag(PIPE_FIX, PIPE_MTE1, event_1); + TMOV(b_l0_tile[1], Y_l1_tile); + set_flag(PIPE_MTE1, PIPE_M, event_1); + + wait_flag(PIPE_FIX, PIPE_M, event_0); // from previous iter + wait_flag(PIPE_MTE1, PIPE_M, event_0); // from loading a_l0[0], b_l0[0] + TMATMUL(c_l0_tile[0], a_l0_tile[0], b_l0_tile[0]); // c_l0[0] contains X + set_flag(PIPE_M, PIPE_FIX, event_0); + wait_flag(PIPE_M, PIPE_FIX, event_0); + set_flag(PIPE_FIX, PIPE_M, event_0); + wait_flag(PIPE_FIX, PIPE_M, 
event_0); + + if (block_size < FractalSize / 4) { // Update Y except in last iteration + wait_flag(PIPE_M, PIPE_MTE1, event_1); // from previous iter + TMOV(a_l0_tile[1], Y_l1_tile); + wait_flag(PIPE_MTE1, PIPE_M, event_1); + set_flag(PIPE_MTE1, PIPE_M, event_1); + + wait_flag(PIPE_MTE1, PIPE_M, event_1); + wait_flag(PIPE_FIX, PIPE_M, event_1); // from previous iter + TMATMUL(c_l0_tile[1], a_l0_tile[1], b_l0_tile[1]); + set_flag(PIPE_M, PIPE_MTE1, event_1); // for next iter + set_flag(PIPE_M, PIPE_FIX, event_1); + set_flag(PIPE_MTE1, PIPE_M, event_1); + + wait_flag(PIPE_M, PIPE_FIX, event_1); + TMOV(Y_l1_tile, c_l0_tile[1]); + set_flag(PIPE_FIX, PIPE_M, event_1); // for next iter + } + set_flag(PIPE_FIX, PIPE_MTE1, event_1); // for next iter + + wait_flag(PIPE_MTE1, PIPE_M, event_1); + TMATMUL_ACC(c_l0_tile[0], c_l0_tile[0], a_l0_tile[0], + b_l0_tile[1]); // c_l0[0] has X + X @ Y + set_flag(PIPE_M, PIPE_MTE1, event_0); + set_flag(PIPE_M, PIPE_FIX, event_0); + + wait_flag(PIPE_M, PIPE_FIX, event_0); + TMOV(X_l1_tile, c_l0_tile[0]); + set_flag(PIPE_FIX, PIPE_M, event_0); // for next iter + set_flag(PIPE_FIX, PIPE_MTE1, event_0); // for next iter + } + wait_flag(PIPE_FIX, PIPE_MTE1, event_1); // only for update Y + wait_flag(PIPE_M, PIPE_MTE1, event_1); // only for update Y + wait_flag(PIPE_FIX, PIPE_M, event_1); // only for update Y + wait_flag(PIPE_FIX, PIPE_MTE1, event_0); + wait_flag(PIPE_M, PIPE_MTE1, event_0); + wait_flag(PIPE_FIX, PIPE_M, event_0); + + /* + * Unrolled recursion part: + * block_size = FractalSize + * while block_size < MatrixSize: + * LX = even_blocks(X, block_size) + * RX = odd_blocks(X, block_size) + * Y = LX @ (-M) + I + * X = Y @ RX + LX + * block_size *= 2 + * + * Comments: + * Upper-tri (swap_parity=false): + * LX = even_blocks(X), RX = odd_blocks(X) + * Y = LX @ (-M) + I, X = Y @ RX + LX + * Lower-tri (swap_parity=true): + * RX = even→L0A(odd via swap), LX = odd→L0B(even via swap) + * Y = RX @ (-M) + I, X = Y @ LX + RX + */ + 
TMOV(b_l0_tile[1], M_neg_l1_tile); // b_l0[1] contains M_neg + TMOV(a_l0_tile[0], I_l1_tile); // a_l0[0] contains I + + if constexpr (MatrixSize > FractalSize) { + set_flag(PIPE_FIX, PIPE_M, event_1); + } + set_flag(PIPE_M, PIPE_MTE1, event_1); + set_flag(PIPE_M, PIPE_MTE1, event_0); + set_flag(PIPE_FIX, PIPE_MTE1, event_1); + set_flag(PIPE_FIX, PIPE_M, event_0); + for (uint32_t block_size = FractalSize; block_size < MatrixSize; + block_size *= 2) { + wait_flag(PIPE_M, PIPE_MTE1, event_0); // Wait for last iter a_l0[1] + TMOV(a_l0_tile[1], Zero_l1_tile); + + wait_flag(PIPE_M, PIPE_MTE1, event_1); + TMOV(b_l0_tile[0], I_l1_tile); + set_flag(PIPE_MTE1, PIPE_M, event_0); + + wait_flag(PIPE_FIX, PIPE_MTE1, event_1); // Wait to write last X + CopyOddOrEvenBlocksL1ToL0( + X_l1_tile, a_l0_tile[1], block_size, + swap_parity); // a_l0[1]: even(LX) or odd(RX) + set_flag(PIPE_MTE1, PIPE_M, event_1); + + wait_flag(PIPE_MTE1, PIPE_M, event_0); + wait_flag(PIPE_FIX, PIPE_M, event_0); // Wait c_l0[0] from previous iter + TMATMUL(c_l0_tile[0], a_l0_tile[0], b_l0_tile[0]); // c_l0[0] has I + + wait_flag(PIPE_MTE1, PIPE_M, event_1); + wait_flag(PIPE_FIX, PIPE_M, event_1); // Wait c_l0[1] from previous iter + TMATMUL(c_l0_tile[1], a_l0_tile[1], b_l0_tile[0]); // c_l0[1] contains LX + set_flag(PIPE_M, PIPE_MTE1, event_1); // allow to load RX on b_l0[0] + + TMATMUL_ACC(c_l0_tile[0], c_l0_tile[0], a_l0_tile[1], + b_l0_tile[1]); // c_l0[0] <- LX * M_neg + I + set_flag(PIPE_M, PIPE_FIX, event_0); + set_flag(PIPE_M, PIPE_MTE1, event_0); + + wait_flag(PIPE_M, PIPE_FIX, event_0); + TMOV(Y_l1_tile, c_l0_tile[0]); // Y_l1 contains LX * M_neg + I + set_flag(PIPE_FIX, PIPE_MTE1, event_0); + set_flag(PIPE_FIX, PIPE_M, event_0); + + /* Load complementary blocks of X in L0B. 
If swap_parity = false, "Load Odd + * Blocks Of X In L0B" */ + wait_flag(PIPE_M, PIPE_MTE1, event_1); + TMOV(b_l0_tile[0], Zero_l1_tile); + CopyOddOrEvenBlocksL1ToL0( + X_l1_tile, b_l0_tile[0], block_size, + swap_parity); // b_l0[0]: odd(RX) or even(LX) + + wait_flag(PIPE_M, PIPE_MTE1, event_0); // Wait for previous use of a_l0[1] + wait_flag(PIPE_FIX, PIPE_MTE1, event_0); // Wait for Y_l1 + TMOV(a_l0_tile[1], Y_l1_tile); // a_l0[1] contains LX * M_neg + I + set_flag(PIPE_MTE1, PIPE_M, event_0); + + wait_flag(PIPE_MTE1, PIPE_M, event_0); + TMATMUL_ACC(c_l0_tile[1], c_l0_tile[1], a_l0_tile[1], b_l0_tile[0]); + set_flag(PIPE_M, PIPE_MTE1, event_0); // next iter can read on a_l0[1] + set_flag(PIPE_M, PIPE_MTE1, event_1); // next iter can read on b_l0[0] + set_flag(PIPE_M, PIPE_FIX, event_0); + wait_flag(PIPE_M, PIPE_FIX, event_0); + + if (block_size < MatrixSize / 2) { // Update X_l1 except in last iteration + TMOV(X_l1_tile, c_l0_tile[1]); + set_flag(PIPE_FIX, PIPE_M, event_1); // release c_l0[1] for next iter + } + set_flag(PIPE_FIX, PIPE_MTE1, event_1); + } + wait_flag(PIPE_M, PIPE_MTE1, event_0); + wait_flag(PIPE_M, PIPE_MTE1, event_1); + wait_flag(PIPE_FIX, PIPE_M, event_0); + wait_flag(PIPE_FIX, PIPE_MTE1, event_1); // Write c_l0[1] to X_l1 +} + +/** + * @brief: Runs the main kernel (inverts all matrices in the tensor) + * + * @tparam InputT The type of the input elements. Supports fp16 and bf16. + * @tparam OutputT The type of the output elements. + * @tparam MatrixSize Size of the entire input/output matrices. + * @tparam NumTilesPerCubeIter How many matrices to load and invert in a single + * cube iteration. + * @tparam IsBSND If IsBSND is false, then the last two dimensions represent a + * 2D triangular matrix in row-major format, while the other dimensions are + * batch dimensions. 
If IsBSND is true, then the dimensions represent in order: + * B batch size, S sequence length (which is chunked in tiles of size D), N + * number of heads (equivalent to a second batch dimension for this kernel), and + * D chunk size. The inverse is over the dimensions S (chunked) and D, row-major + * within each tile. + * + * @param M_inv pointer to the global memory to store the final inverse. + * @param M Pointer to the global tensor matrix in global memory. + * @param I_neg Pointer to global memory that contains the negative identity. + * @param total_tiles The total number of matrices to invert. + * @param num_bsnd_heads The number of heads, only for BSND format. + * @param is_lower If input matrices are lower-triangular (is_lower == 1) or + * upper-triangular (is_lower == 0). Default is upper triangular. + * @param cu_seqlens Cumulative sequence lengths (BSND varlen), or nullptr. + */ +template +AICORE inline void TriInvRecUnrollKernel(__gm__ OutputT* M_inv, + __gm__ InputT* M, __gm__ InputT* I_neg, + uint32_t total_tiles, + uint32_t num_bsnd_heads = 0, + uint32_t is_lower = 0, + __gm__ int32_t* cu_seqlens = nullptr) { + /* Initializations */ + constexpr uint32_t TileLen = MatrixSize * MatrixSize; + constexpr uint32_t FractalSize = 16; // fractal size for half /bf16 + constexpr uint32_t NumFractalsRowWise = MatrixSize / FractalSize; + constexpr uint32_t NumL0Buffers = 2; + + if (get_block_idx() * NumTilesPerCubeIter >= total_tiles) { + return; + } + + using GlobalTileShapeIn = + TileShape2D; + using GlobalTileStridesIn = typename std::conditional< + !IsBSND, BaseShape2D, + Stride<1, 1, 1, -1, 1>>::type; + using GlobalTileIn = + GlobalTensor; + using GlobalTileDynamicShape = Shape<1, 1, 1, DYNAMIC, DYNAMIC>; + using GlobalTileDynamicStride = Stride<1, 1, 1, DYNAMIC, 1>; + using GlobalTileDynamicIn = GlobalTensor; + using GlobalTileStridesINeg = + BaseShape2D; + using GlobalTileINeg = GlobalTensor; + + using GlobalTileShapeOut = + TileShape2D; + using 
GlobalTileStridesOut = typename std::conditional< + !IsBSND, BaseShape2D, + Stride<1, 1, 1, -1, 1>>::type; + using GlobalTileOut = GlobalTensor; + using GlobalTileDynamicOut = + GlobalTensor; + using TileL1AB = + Tile; + using TileL1ABDynamic = + Tile; + + // L0 Memory + using TileL0A = TileLeft; + using TileL0B = TileRight; + using TileL0C = TileAcc; + using TileL0CDynamic = + TileAcc; + + GlobalTileINeg I_neg_global_in(I_neg); + + TileL1AB X_l1_tile; + TileL1AB I_l1_tile; + TileL1AB I_neg_l1_tile; + TileL1AB M_neg_l1_tile; + TileL1AB Zero_l1_tile; + TileL1AB Y_l1_tile[NumTilesPerCubeIter]; + + TileL0A a_l0_tile[NumL0Buffers]; + TileL0B b_l0_tile[NumL0Buffers]; + TileL0C c_l0_tile[NumL0Buffers]; + + TASSIGN(I_l1_tile, 0x0); + TASSIGN(I_neg_l1_tile, 0x0 + TileLen * sizeof(InputT)); + TASSIGN(Zero_l1_tile, 0x0 + 2 * TileLen * sizeof(InputT)); + TASSIGN(M_neg_l1_tile, 0x0 + 3 * TileLen * sizeof(InputT)); + TASSIGN(X_l1_tile, 0x0 + 4 * TileLen * sizeof(InputT)); + for (uint32_t tile_id = 0; tile_id < NumTilesPerCubeIter; ++tile_id) { + TASSIGN(Y_l1_tile[tile_id], 0x0 + (5 + tile_id) * TileLen * sizeof(InputT)); + } + + for (uint32_t buffer_num = 0; buffer_num < NumL0Buffers; ++buffer_num) { + TASSIGN(a_l0_tile[buffer_num], 0x0 + buffer_num * TileLen * sizeof(InputT)); + TASSIGN(b_l0_tile[buffer_num], 0x0 + buffer_num * TileLen * sizeof(InputT)); + TASSIGN(c_l0_tile[buffer_num], 0x0 + buffer_num * TileLen * sizeof(float)); + } + TLOAD(I_neg_l1_tile, I_neg_global_in); + set_flag(PIPE_MTE2, PIPE_MTE1, static_cast(0)); + wait_flag(PIPE_MTE2, PIPE_MTE1, static_cast(0)); + + PrepareAuxiliaryMatrices( + I_neg_l1_tile, Zero_l1_tile, I_l1_tile, a_l0_tile[0], b_l0_tile[0], + c_l0_tile[0]); + + const uint32_t max_iters_per_aic = + CeilDiv(total_tiles, (uint32_t)(NumTilesPerCubeIter * get_block_num())); + + /* Main iteration - Compute all tiles */ + uint32_t bsnd_tile_offsets[NumTilesPerCubeIter] = {0}; + uint32_t bsnd_tile_valid_sizes[NumTilesPerCubeIter] = {0}; + uint32_t 
next_tile_id_that_waits_for_pipe_fix_pipe_m = 0; + set_flag(PIPE_FIX, PIPE_M, + static_cast(next_tile_id_that_waits_for_pipe_fix_pipe_m)); + for (uint32_t tile_id = 0; tile_id < NumTilesPerCubeIter; ++tile_id) { + set_flag(PIPE_M, PIPE_MTE2, static_cast(tile_id)); + } + for (uint32_t cube_iter = 0; cube_iter < max_iters_per_aic; ++cube_iter) { + const uint32_t global_index = + (cube_iter * get_block_num() + get_block_idx()) * NumTilesPerCubeIter; + if (global_index >= total_tiles) { + break; + } + for (uint32_t tile_id = 0; (tile_id < NumTilesPerCubeIter) && + (global_index + tile_id < total_tiles); + ++tile_id) { + if constexpr (IsBSND) { + const uint32_t global_tile_id = global_index + tile_id; + if (cu_seqlens != nullptr) { + const BSNDVarlenTileInfo tile_info = + kernel_utils::GetBSNDVarlenTileInfoFromCuSeqlens( + global_tile_id, num_bsnd_heads, MatrixSize, cu_seqlens); + bsnd_tile_offsets[tile_id] = tile_info.bsnd_offset; + bsnd_tile_valid_sizes[tile_id] = tile_info.valid_size; + } else { + bsnd_tile_offsets[tile_id] = kernel_utils::GetBSNDFixedTileOffset( + global_tile_id, num_bsnd_heads, MatrixSize); + bsnd_tile_valid_sizes[tile_id] = MatrixSize; + } + const uint32_t bsnd_offset = bsnd_tile_offsets[tile_id]; + const uint32_t valid_size = bsnd_tile_valid_sizes[tile_id]; + const int row_stride = static_cast(MatrixSize * num_bsnd_heads); + wait_flag(PIPE_M, PIPE_MTE2, static_cast(tile_id)); + if (valid_size < MatrixSize) { + TileL1ABDynamic Y_dyn_l1_tile(valid_size, valid_size); + TASSIGN(Y_dyn_l1_tile, + 0x0 + (5 + tile_id) * TileLen * sizeof(InputT)); + GlobalTileDynamicIn M_global_in_dyn( + M + bsnd_offset, + {1, 1, 1, static_cast(valid_size), + static_cast(valid_size)}, + {1, 1, 1, row_stride, 1}); + TLOAD(Y_dyn_l1_tile, M_global_in_dyn); + set_flag(PIPE_MTE2, PIPE_MTE1, static_cast(tile_id)); + wait_flag(PIPE_MTE2, PIPE_MTE1, static_cast(tile_id)); + TFILLPAD(Y_dyn_l1_tile, Y_dyn_l1_tile); + } else { + GlobalTileIn M_global_in(M + bsnd_offset, {}, 
{row_stride}); + TLOAD(Y_l1_tile[tile_id], M_global_in); + } + } else { + GlobalTileIn M_global_in(M + (global_index + tile_id) * TileLen); + wait_flag(PIPE_M, PIPE_MTE2, static_cast(tile_id)); + TLOAD(Y_l1_tile[tile_id], + M_global_in); // Copies NumTilesPerCubeIter tiles at once + } + set_flag(PIPE_MTE2, PIPE_MTE1, static_cast(tile_id)); + } + + constexpr uint32_t final_c_buffer_index = MatrixSize > FractalSize ? 1 : 0; + for (uint32_t tile_id = 0; (tile_id < NumTilesPerCubeIter) && + (global_index + tile_id < total_tiles); + ++tile_id) { + // Wait for previous cube iter to write result + wait_flag(PIPE_FIX, PIPE_M, static_cast(tile_id)); + // Wait for loading new matrices from GM + wait_flag(PIPE_MTE2, PIPE_MTE1, static_cast(tile_id)); + + InvertSingleTile( + X_l1_tile, I_l1_tile, I_neg_l1_tile, M_neg_l1_tile, Zero_l1_tile, + Y_l1_tile[tile_id], a_l0_tile, b_l0_tile, c_l0_tile, tile_id, + is_lower != 0); + + // Allow next cube_iter to proceed for this tile_id + set_flag(PIPE_M, PIPE_MTE2, static_cast(tile_id)); + + /* Store result */ + if constexpr (IsBSND) { + const uint32_t bsnd_offset = bsnd_tile_offsets[tile_id]; + const uint32_t valid_size = bsnd_tile_valid_sizes[tile_id]; + const int row_stride = static_cast(MatrixSize * num_bsnd_heads); + if (valid_size < MatrixSize) { + TileL0CDynamic c_l0_tail_tile(valid_size, valid_size); + TASSIGN(c_l0_tail_tile, + 0x0 + final_c_buffer_index * TileLen * sizeof(float)); + GlobalTileDynamicOut M_inv_global_out_dyn( + M_inv + bsnd_offset, + {1, 1, 1, static_cast(valid_size), + static_cast(valid_size)}, + {1, 1, 1, row_stride, 1}); + TSTORE(M_inv_global_out_dyn, c_l0_tail_tile); + } else { + GlobalTileOut M_inv_global_out(M_inv + bsnd_offset, {}, {row_stride}); + TSTORE(M_inv_global_out, c_l0_tile[final_c_buffer_index]); + } + } else { + GlobalTileOut M_inv_global_out(M_inv + + (global_index + tile_id) * TileLen); + TSTORE(M_inv_global_out, c_l0_tile[final_c_buffer_index]); + } + 
next_tile_id_that_waits_for_pipe_fix_pipe_m = + (tile_id + 1) % NumTilesPerCubeIter; + set_flag( + PIPE_FIX, PIPE_M, + static_cast(next_tile_id_that_waits_for_pipe_fix_pipe_m)); + } + } + for (uint32_t tile_id = 0; tile_id < NumTilesPerCubeIter; ++tile_id) { + wait_flag(PIPE_M, PIPE_MTE2, static_cast(tile_id)); + } + wait_flag(PIPE_FIX, PIPE_M, + static_cast(next_tile_id_that_waits_for_pipe_fix_pipe_m)); +} + +/* + * @brief: Computes the inverses of the blocks of tensor M + */ +template +AICORE void runKernelTriInvRecUnroll(__gm__ OutputT* M_inv, __gm__ InputT* M, + __gm__ InputT* I_neg, uint32_t total_tiles, + uint32_t num_bsnd_heads = 0, + uint32_t is_lower = 0, + __gm__ int32_t* cu_seqlens = nullptr) { +#if defined(__DAV_CUBE__) // Cube compilation + + TriInvRecUnrollKernel(M_inv, M, I_neg, total_tiles, num_bsnd_heads, + is_lower, cu_seqlens); +#else +// Nothing to do on AIV +#endif +} + +template +AICORE void run_tri_inv_rec_unroll(__gm__ OutputT* tensor_out, + __gm__ InputT* tensor_in, + __gm__ InputT* minus_eye_in, + uint32_t matrix_size, uint32_t num_matrices, + uint32_t num_bsnd_heads, + uint32_t is_lower = 0, + __gm__ int32_t* cu_seqlens = nullptr) { + static_assert( + std::is_same_v or std::is_same_v, + "tri_inv_rec_unroll supports only fp16 or bf16."); + + static_assert( + std::is_same_v or std::is_same_v, + "tri_inv_rec_unroll supports only fp16 or bf16."); + switch (matrix_size) { + case 16: + runKernelTriInvRecUnroll(tensor_out, tensor_in, minus_eye_in, + num_matrices, num_bsnd_heads, is_lower, + cu_seqlens); + break; + case 32: + runKernelTriInvRecUnroll(tensor_out, tensor_in, minus_eye_in, + num_matrices, num_bsnd_heads, is_lower, + cu_seqlens); + break; + case 64: + runKernelTriInvRecUnroll(tensor_out, tensor_in, minus_eye_in, + num_matrices, num_bsnd_heads, is_lower, + cu_seqlens); + break; + case 128: + runKernelTriInvRecUnroll(tensor_out, tensor_in, minus_eye_in, + num_matrices, num_bsnd_heads, is_lower, + cu_seqlens); + break; + } +} + 
+template +AICORE void run_tri_inv_rec_unroll_per_num_matrices( + __gm__ OutputT* tensor_out, __gm__ InputT* tensor_in, + __gm__ InputT* minus_eye_in, uint32_t matrix_size, uint32_t num_matrices, + uint32_t num_bsnd_heads, uint32_t is_lower = 0, + __gm__ int32_t* cu_seqlens = nullptr) { + if (num_bsnd_heads == 0) { + if (num_matrices <= get_block_num()) { + run_tri_inv_rec_unroll( + tensor_out, tensor_in, minus_eye_in, matrix_size, num_matrices, + num_bsnd_heads, is_lower, cu_seqlens); + } else if (num_matrices <= 2 * get_block_num()) { + run_tri_inv_rec_unroll( + tensor_out, tensor_in, minus_eye_in, matrix_size, num_matrices, + num_bsnd_heads, is_lower, cu_seqlens); + } else { + run_tri_inv_rec_unroll( + tensor_out, tensor_in, minus_eye_in, matrix_size, num_matrices, + num_bsnd_heads, is_lower, cu_seqlens); + } + } else { + if (num_matrices <= get_block_num()) { + run_tri_inv_rec_unroll( + tensor_out, tensor_in, minus_eye_in, matrix_size, num_matrices, + num_bsnd_heads, is_lower, cu_seqlens); + } else if (num_matrices <= 2 * get_block_num()) { + run_tri_inv_rec_unroll( + tensor_out, tensor_in, minus_eye_in, matrix_size, num_matrices, + num_bsnd_heads, is_lower, cu_seqlens); + } else { + run_tri_inv_rec_unroll( + tensor_out, tensor_in, minus_eye_in, matrix_size, num_matrices, + num_bsnd_heads, is_lower, cu_seqlens); + } + } +} + +// --------------------------------------------------------------------------- +// simpler framework entry point +// --------------------------------------------------------------------------- + +extern "C" __aicore__ __attribute__((always_inline)) void +kernel_entry(__gm__ int64_t* args) { + __gm__ Tensor* M_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ Tensor* I_neg_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ Tensor* M_inv_tensor = reinterpret_cast<__gm__ Tensor*>(args[2]); + __gm__ Tensor* config_tensor = reinterpret_cast<__gm__ Tensor*>(args[3]); + + __gm__ int64_t* cfg = + reinterpret_cast<__gm__ 
int64_t*>(config_tensor->buffer.addr); + const uint32_t matrix_size = static_cast(cfg[0]); + const uint32_t num_matrices = static_cast(cfg[1]); + const uint32_t is_lower = static_cast(cfg[2]); + + __gm__ half* M = reinterpret_cast<__gm__ half*>(M_tensor->buffer.addr) + + M_tensor->start_offset; + __gm__ half* I_neg = + reinterpret_cast<__gm__ half*>(I_neg_tensor->buffer.addr) + + I_neg_tensor->start_offset; + __gm__ half* M_inv = + reinterpret_cast<__gm__ half*>(M_inv_tensor->buffer.addr) + + M_inv_tensor->start_offset; + + run_tri_inv_rec_unroll_per_num_matrices( + M_inv, M, I_neg, matrix_size, num_matrices, 0, is_lower, nullptr); +} diff --git a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/orchestration/triangular_inverse_orch.cpp b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/orchestration/triangular_inverse_orch.cpp index d5001012e..f02ddbfae 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/orchestration/triangular_inverse_orch.cpp +++ b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/orchestration/triangular_inverse_orch.cpp @@ -9,22 +9,18 @@ * ----------------------------------------------------------------------------------------------------------- */ /** - * BGEMM Orchestration Function (tensormap_and_ringbuffer Runtime) + * Triangular Inverse Orchestration (tensormap_and_ringbuffer Runtime) * - * Builds the task graph for tiled matrix multiplication: C = A @ B + * Builds the task graph for batch triangular matrix inversion. 
* - * Configuration read from scalar args (set in golden.py): - * - tile_size: tile dimension (tile_size x tile_size per tile) - * - grid_k: number of K-dimension partitions - * - num_groups: number of independent groups (= matmul_add_task_num / grid_k) - * - incore_loop: number of tiles per group + * Arg layout (set in test_triangular_inverse.py): + * tensor(0) = M (INPUT) fp16 triangular matrices [num_matrices * N * N] + * tensor(1) = I_neg (INPUT) fp16 negative identity [N * N] + * tensor(2) = M_inv (OUTPUT) fp16 result [num_matrices * N * N] + * tensor(3) = config (INPUT) int64[3]: [matrix_size, num_matrices, is_lower] * - * Memory layout (tile-first, flattened): - * A: [num_groups, grid_k, incore_loop, tile_size, tile_size] - * B: [num_groups, grid_k, incore_loop, tile_size, tile_size] - * C: [incore_loop * num_groups, tile_size, tile_size] - * - * Arg layout: [A, B, C, config] + * The single AIC task (func_id=0) receives these four args in the same order + * and dispatches to run_tri_inv_rec_unroll_per_num_matrices. 
*/ #include @@ -32,53 +28,41 @@ #include "pto_orchestration_api.h" // NOLINT(build/include_subdir) -#define FUNC_GEMM_TILE 0 -#define FUNC_TILE_ADD 1 +#define FUNC_TRI_INV 0 extern "C" { __attribute__((visibility("default"))) PTO2OrchestrationConfig -aicpu_orchestration_config(const ChipStorageTaskArgs &orch_args) { - (void)orch_args; // NOLINT(readability/casting) +aicpu_orchestration_config(const ChipStorageTaskArgs& orch_args) { + (void)orch_args; return PTO2OrchestrationConfig{ .expected_arg_count = 4, }; } -__attribute__((visibility("default"))) void aicpu_orchestration_entry(const ChipStorageTaskArgs &orch_args) { - // Tensor args - Tensor ext_A = from_tensor_arg(orch_args.tensor(0)); - Tensor ext_B = from_tensor_arg(orch_args.tensor(1)); - Tensor ext_C = from_tensor_arg(orch_args.tensor(2)); +__attribute__((visibility("default"))) void +aicpu_orchestration_entry(const ChipStorageTaskArgs& orch_args) { + Tensor ext_M = from_tensor_arg(orch_args.tensor(0)); + Tensor ext_I_neg = from_tensor_arg(orch_args.tensor(1)); + Tensor ext_M_inv = from_tensor_arg(orch_args.tensor(2)); Tensor ext_config = from_tensor_arg(orch_args.tensor(3)); - // Read config from tensor data: [tile_size, grid_k, num_groups, incore_loop] - int64_t *host_config = orch_args.tensor(3).data_as(); - int tile_size = static_cast(host_config[0]); - int grid_k = static_cast(host_config[1]); - int num_groups = static_cast(host_config[2]); - int incore_loop = static_cast(host_config[3]); - uint64_t tile_elems = static_cast(tile_size) * tile_size; - - int grid_m = 1; - int grid_n = 1; + int64_t* host_config = orch_args.tensor(3).data_as(); + int matrix_size = static_cast(host_config[0]); + int num_matrices = static_cast(host_config[1]); + int is_lower = static_cast(host_config[2]); LOG_INFO_V0( - "[bgemm_orch] tile_size: %d, grid_m: %d, grid_n: %d, grid_k: %d, num_groups: %d, incore_loop: %d", tile_size, - grid_m, grid_n, grid_k, num_groups, incore_loop + "[tri_inv_orch] matrix_size: %d, num_matrices: 
%d, is_lower: %d", + matrix_size, num_matrices, is_lower ); - uint32_t tile_shapes[1] = {static_cast(tile_elems)}; - uint64_t group_tile_elems = static_cast(incore_loop) * tile_elems; - uint32_t group_shapes[1] = {static_cast(group_tile_elems)}; - TensorCreateInfo group_ci(group_shapes, 1, DataType::FLOAT32); - - Arg params_gemm; - params_gemm.add_input(ext_A); - params_gemm.add_input(ext_B); - params_gemm.add_output(ext_C); - params_gemm.add_input(ext_config); - TaskOutputTensors gemm_outs = rt_submit_aic_task(FUNC_GEMM_TILE, params_gemm); + Arg params; + params.add_input(ext_M); + params.add_input(ext_I_neg); + params.add_output(ext_M_inv); + params.add_input(ext_config); + rt_submit_aic_task(FUNC_TRI_INV, params); } } // extern "C" diff --git a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py index eff26efcd..1bdce9c04 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py +++ b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py @@ -7,7 +7,7 @@ # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. 
# ----------------------------------------------------------------------------------------------------------- -"""Test triangular inverse cube-only method: runtime-configurable C = torch.linalg.triangular_solve(A, torch.eye(A.shape[-1])).""" +"""Triangular inverse (recursive unrolled): M_inv = inv(M) for upper/lower-triangular fp16 matrices.""" import torch from simpler.task_interface import ArgDirection as D @@ -17,8 +17,9 @@ @scene_test(level=2, runtime="tensormap_and_ringbuffer") class TestTriangularInverse(SceneTestCase): - RTOL = 1e-3 - ATOL = 1e-3 + # fp16 arithmetic — use tolerances appropriate for half-precision results + RTOL = 1e-2 + ATOL = 1e-2 CALLABLE = { "orchestration": { @@ -29,55 +30,76 @@ class TestTriangularInverse(SceneTestCase): "incores": [ { "func_id": 0, - "name": "GEMM", - "source": "kernels/aic/kernel_simple_matmul.cpp", + "name": "TRI_INV", + "source": "kernels/aic/kernel_tri_inv_rec_unroll.cpp", "core_type": "aic", - "signature": [D.IN, D.IN, D.OUT], + "signature": [D.IN, D.IN, D.OUT, D.IN], } ], } CASES = [ { - "name": "Case0", + "name": "Case0_upper32", "platforms": ["a2a3sim", "a2a3"], "config": {"aicpu_thread_num": 4, "block_dim": 16}, - "params": {"batch_dim": 4, "incore_data_size": 128}, + "params": {"num_matrices": 4, "matrix_size": 32, "is_lower": 0}, }, { - "name": "Case1", + "name": "Case1_upper64", "manual": True, "platforms": ["a2a3sim", "a2a3"], "config": {"aicpu_thread_num": 4, "block_dim": 16}, - "params": {"batch_dim": 4, "incore_data_size": 128}, + "params": {"num_matrices": 4, "matrix_size": 64, "is_lower": 0}, }, { - "name": "Case2", + "name": "Case2_lower32", "manual": True, "platforms": ["a2a3sim", "a2a3"], "config": {"aicpu_thread_num": 4, "block_dim": 16}, - "params": {"batch_dim": 4, "incore_data_size": 128}, + "params": {"num_matrices": 4, "matrix_size": 32, "is_lower": 1}, }, ] def generate_args(self, params): - tile_size = params["incore_data_size"] - batch_dim = params["batch_dim"] - A = torch.randn(batch_dim, 
tile_size, tile_size, dtype=torch.float32) * 0.1 - B = torch.randn(batch_dim, tile_size, tile_size, dtype=torch.float32) * 0.1 - C = torch.zeros(batch_dim, tile_size, tile_size, dtype=torch.float32) - config = torch.tensor([tile_size, batch_dim], dtype=torch.int64) + n = params["matrix_size"] + num_matrices = params["num_matrices"] + is_lower = params["is_lower"] + + # Build well-conditioned triangular matrices in fp16. + # Start with random values and zero out the off-triangle, then set + # the diagonal to a value in [0.5, 1.5] to ensure invertibility. + M_f32 = torch.rand(num_matrices, n, n, dtype=torch.float32) + if is_lower: + M_f32 = torch.tril(M_f32) + else: + M_f32 = torch.triu(M_f32) + # Diagonal entries in [0.5, 1.5] — keeps condition number reasonable + diag_vals = torch.rand(num_matrices, n, dtype=torch.float32) + 0.5 + idx = torch.arange(n) + M_f32[:, idx, idx] = diag_vals + M = M_f32.to(torch.float16) + + # Negative identity matrix — used by the kernel to derive I and Zero + I_neg = (-torch.eye(n, dtype=torch.float16)).flatten() + + M_inv = torch.zeros(num_matrices, n, n, dtype=torch.float16) + config = torch.tensor([n, num_matrices, is_lower], dtype=torch.int64) + return TaskArgsBuilder( - Tensor("A", A.flatten()), Tensor("B", B.flatten()), Tensor("C", C.flatten()), Tensor("config", config) + Tensor("M", M.flatten()), + Tensor("I_neg", I_neg), + Tensor("M_inv", M_inv.flatten()), + Tensor("config", config), ) def compute_golden(self, args, params): - tile_size = params["incore_data_size"] - batch_dim = params["batch_dim"] - A = args.A.reshape(batch_dim, tile_size, tile_size) - B = args.B.reshape(batch_dim, tile_size, tile_size) - C = A @ B - return C + n = params["matrix_size"] + num_matrices = params["num_matrices"] + M = args.M.reshape(num_matrices, n, n).to(torch.float32) + # Compute in float32 for numerical stability, then cast to fp16 + M_inv = torch.linalg.inv(M).to(torch.float16) + return M_inv if __name__ == "__main__": From 
f18f90e043dac9a13b7942391a70141d41a473b3 Mon Sep 17 00:00:00 2001 From: anastasios Date: Wed, 20 May 2026 06:13:12 +0000 Subject: [PATCH 06/16] fix --- .../kernels/aic/kernel_tri_inv_rec_unroll.cpp | 1262 ++++++++--------- .../orchestration/triangular_inverse_orch.cpp | 24 +- .../test_triangular_inverse.py | 69 +- 3 files changed, 651 insertions(+), 704 deletions(-) diff --git a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_tri_inv_rec_unroll.cpp b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_tri_inv_rec_unroll.cpp index cad69147c..19471ccfb 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_tri_inv_rec_unroll.cpp +++ b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_tri_inv_rec_unroll.cpp @@ -10,7 +10,7 @@ for the full License text. * Triangular matrix inverse kernel (recursive unrolled algorithm). * * Adapted from pto-kernels for the simpler framework: - * - kernel_utils.h content is inlined below + * - tri_inv_utils.h content is inlined below * - constants.h (static -I matrices) is replaced by a tensor arg passed * from the Python test via the orchestration * - The pto-kernels __global__ entry points are replaced by a single @@ -24,6 +24,7 @@ for the full License text. */ #include +#include #include #include "tensor.h" @@ -31,7 +32,7 @@ for the full License text. 
using namespace pto; #ifndef GM_ADDR -#define GM_ADDR __gm__ uint8_t* +#define GM_ADDR __gm__ uint8_t * #endif #ifndef __gm__ @@ -42,67 +43,58 @@ using namespace pto; #define __aicore__ [aicore] #endif -// --------------------------------------------------------------------------- -// Inlined content of kernel_utils.h (from pto-kernels) -// --------------------------------------------------------------------------- -namespace kernel_utils { +namespace tri_inv_utils { template AICORE inline void SetWaitFlag(uint32_t id) { - set_flag(SrcPipe, DstPipe, static_cast(id)); - wait_flag(SrcPipe, DstPipe, static_cast(id)); + set_flag(SrcPipe, DstPipe, static_cast(id)); + wait_flag(SrcPipe, DstPipe, static_cast(id)); } -template ::value && - std::is_integral::value, - int>::type = 0> +template < + typename T1, typename T2, + typename std::enable_if::value && std::is_integral::value, int>::type = 0> AICORE inline T1 CeilDiv(T1 value, T2 divisor) { - return (value + divisor - 1) / divisor; + return (value + divisor - 1) / divisor; } -#define BSND_OFFSET(tile_id, N, S, D) \ - (((tile_id) / (N)) * (S) * (N) * (D) + ((tile_id) % (N)) * (D)) +#define BSND_OFFSET(tile_id, N, S, D) (((tile_id) / (N)) * (S) * (N) * (D) + ((tile_id) % (N)) * (D)) -AICORE inline uint32_t GetBSNDFixedTileOffset(uint32_t tile_id, - uint32_t num_bsnd_heads, - uint32_t matrix_size) { - return BSND_OFFSET(tile_id, num_bsnd_heads, matrix_size, matrix_size); +AICORE inline uint32_t GetBSNDFixedTileOffset(uint32_t tile_id, uint32_t num_bsnd_heads, uint32_t matrix_size) { + return BSND_OFFSET(tile_id, num_bsnd_heads, matrix_size, matrix_size); } struct BSNDVarlenTileInfo { - uint32_t bsnd_offset; - uint32_t valid_size; + uint32_t bsnd_offset; + uint32_t valid_size; }; AICORE inline BSNDVarlenTileInfo GetBSNDVarlenTileInfoFromCuSeqlens( - uint32_t tile_id, uint32_t num_bsnd_heads, uint32_t matrix_size, - __gm__ int32_t* cu_seqlens) { - const uint32_t head_idx = tile_id % num_bsnd_heads; - const uint32_t 
chunk_idx = tile_id / num_bsnd_heads; - - uint32_t seq_start = static_cast(cu_seqlens[0]); - uint32_t accumulated_chunks = 0; - for (uint32_t seq_idx = 0;; ++seq_idx) { - const uint32_t seq_end = static_cast(cu_seqlens[seq_idx + 1]); - const uint32_t seq_len = seq_end - seq_start; - const uint32_t seq_num_chunks = CeilDiv(seq_len, matrix_size); - if (chunk_idx < accumulated_chunks + seq_num_chunks) { - const uint32_t local_chunk_idx = chunk_idx - accumulated_chunks; - const uint32_t row_start = seq_start + local_chunk_idx * matrix_size; - const uint32_t valid_size = - min(static_cast(seq_end - row_start), matrix_size); - return {row_start * num_bsnd_heads * matrix_size + head_idx * matrix_size, - valid_size}; + uint32_t tile_id, uint32_t num_bsnd_heads, uint32_t matrix_size, __gm__ int32_t *cu_seqlens +) { + const uint32_t head_idx = tile_id % num_bsnd_heads; + const uint32_t chunk_idx = tile_id / num_bsnd_heads; + + uint32_t seq_start = static_cast(cu_seqlens[0]); + uint32_t accumulated_chunks = 0; + for (uint32_t seq_idx = 0;; ++seq_idx) { + const uint32_t seq_end = static_cast(cu_seqlens[seq_idx + 1]); + const uint32_t seq_len = seq_end - seq_start; + const uint32_t seq_num_chunks = CeilDiv(seq_len, matrix_size); + if (chunk_idx < accumulated_chunks + seq_num_chunks) { + const uint32_t local_chunk_idx = chunk_idx - accumulated_chunks; + const uint32_t row_start = seq_start + local_chunk_idx * matrix_size; + const uint32_t valid_size = min(static_cast(seq_end - row_start), matrix_size); + return {row_start * num_bsnd_heads * matrix_size + head_idx * matrix_size, valid_size}; + } + accumulated_chunks += seq_num_chunks; + seq_start = seq_end; } - accumulated_chunks += seq_num_chunks; - seq_start = seq_end; - } } -} // namespace kernel_utils +} // namespace tri_inv_utils -using namespace kernel_utils; +using namespace tri_inv_utils; // --------------------------------------------------------------------------- // Kernel template code (verbatim from pto-kernels 
kernel_tri_inv_rec_unroll.cpp) @@ -123,27 +115,22 @@ using namespace kernel_utils; * @param src Tile in L1 memory. * @param dst Tile in L0A or L0B memory. */ -template +template AICORE inline void CopyDiagonalFractalsL1ToL0(SrcL1TileT src, DstL0TileT dst) { - constexpr uint32_t NumFractals = MatrixSize / FractalSize; - constexpr bool is_left = - std::is_same_v>; - constexpr TileType LeftOrRight = is_left ? TileType::Left : TileType::Right; - constexpr SLayout InnerLayout = - is_left ? SLayout::RowMajor : SLayout::ColMajor; - - Tile - fractals[NumFractals]; - const std::uintptr_t starting_address = - reinterpret_cast(dst.data()); - for (uint32_t i = 0; i < NumFractals; ++i) { - TASSIGN(fractals[i], starting_address + i * FractalSize * - (MatrixSize + FractalSize) * - sizeof(InputT)); - TEXTRACT(fractals[i], src, i * FractalSize, i * FractalSize); - } + constexpr uint32_t NumFractals = MatrixSize / FractalSize; + constexpr bool is_left = std::is_same_v>; + constexpr TileType LeftOrRight = is_left ? TileType::Left : TileType::Right; + constexpr SLayout InnerLayout = is_left ? SLayout::RowMajor : SLayout::ColMajor; + + Tile< + LeftOrRight, InputT, FractalSize, FractalSize, BLayout::RowMajor, FractalSize, FractalSize, InnerLayout, + TileConfig::fractalABSize> + fractals[NumFractals]; + const std::uintptr_t starting_address = reinterpret_cast(dst.data()); + for (uint32_t i = 0; i < NumFractals; ++i) { + TASSIGN(fractals[i], starting_address + i * FractalSize * (MatrixSize + FractalSize) * sizeof(InputT)); + TEXTRACT(fractals[i], src, i * FractalSize, i * FractalSize); + } } /** @@ -170,46 +157,39 @@ AICORE inline void CopyDiagonalFractalsL1ToL0(SrcL1TileT src, DstL0TileT dst) { * unrolled recursion part of the algorithm, where we need to copy alternating * blocks of X in each iteration. 
*/ -template -AICORE inline void CopyOddOrEvenBlocksL1ToL0(SrcL1TileT src, DstL0TileT dst, - uint32_t block_size, - bool swap_parity = false) { - constexpr bool is_left = - std::is_same_v>; - constexpr TileType LeftOrRight = is_left ? TileType::Left : TileType::Right; - constexpr SLayout InnerLayout = - is_left ? SLayout::RowMajor : SLayout::ColMajor; - - // For left: copy even blocks 0, 2, 4, ... (starting_block=0) - // For right: copy odd blocks 1, 3, 5, ... (starting_block=1) - // Default: left→even(0), right→odd(1). swap_parity flips this. - const uint32_t starting_block_index = - (is_left ? 0u : 1u) ^ (swap_parity ? 1u : 0u); - - const uint32_t num_blocks = MatrixSize / block_size; - const uint32_t num_fractals_per_block = block_size / FractalSize; - - // might need fewer fractals if block_size < FractalSize - Tile - fractals[MatrixSize / FractalSize]; - - const std::uintptr_t starting_address = - reinterpret_cast(dst.data()); - for (uint32_t i = 0; i < num_fractals_per_block; ++i) { - for (uint32_t j = 0; j < num_fractals_per_block; ++j) { - for (uint32_t b = starting_block_index; b < num_blocks; b += 2) { - const uint32_t offset = - b * (MatrixSize + FractalSize) * block_size /* block_offset */ + - i * MatrixSize * FractalSize /* col_fractal_offset */ + - j * FractalSize * FractalSize /* row_fractal_offset */; - TASSIGN(fractals[b], starting_address + offset * sizeof(InputT)); - TEXTRACT(fractals[b], src, b * block_size + i * FractalSize, - b * block_size + j * FractalSize); - } +template +AICORE inline void +CopyOddOrEvenBlocksL1ToL0(SrcL1TileT src, DstL0TileT dst, uint32_t block_size, bool swap_parity = false) { + constexpr bool is_left = std::is_same_v>; + constexpr TileType LeftOrRight = is_left ? TileType::Left : TileType::Right; + constexpr SLayout InnerLayout = is_left ? SLayout::RowMajor : SLayout::ColMajor; + + // For left: copy even blocks 0, 2, 4, ... (starting_block=0) + // For right: copy odd blocks 1, 3, 5, ... 
(starting_block=1) + // Default: left→even(0), right→odd(1). swap_parity flips this. + const uint32_t starting_block_index = (is_left ? 0u : 1u) ^ (swap_parity ? 1u : 0u); + + const uint32_t num_blocks = MatrixSize / block_size; + const uint32_t num_fractals_per_block = block_size / FractalSize; + + // might need fewer fractals if block_size < FractalSize + Tile< + LeftOrRight, InputT, FractalSize, FractalSize, BLayout::RowMajor, FractalSize, FractalSize, InnerLayout, + TileConfig::fractalABSize> + fractals[MatrixSize / FractalSize]; + + const std::uintptr_t starting_address = reinterpret_cast(dst.data()); + for (uint32_t i = 0; i < num_fractals_per_block; ++i) { + for (uint32_t j = 0; j < num_fractals_per_block; ++j) { + for (uint32_t b = starting_block_index; b < num_blocks; b += 2) { + const uint32_t offset = b * (MatrixSize + FractalSize) * block_size /* block_offset */ + + i * MatrixSize * FractalSize /* col_fractal_offset */ + + j * FractalSize * FractalSize /* row_fractal_offset */; + TASSIGN(fractals[b], starting_address + offset * sizeof(InputT)); + TEXTRACT(fractals[b], src, b * block_size + i * FractalSize, b * block_size + j * FractalSize); + } + } } - } } /** @@ -227,36 +207,36 @@ AICORE inline void CopyOddOrEvenBlocksL1ToL0(SrcL1TileT src, DstL0TileT dst, * @param b_l0_tile Tile in L0B for matmuls. * @param c_l0_tile Tile in L0C for matmuls. 
*/ -template +template AICORE inline void PrepareAuxiliaryMatrices( - TileL1AB I_neg_l1_tile, TileL1AB Zero_l1_tile, TileL1AB I_l1_tile, - TileL0A a_l0_tile, TileL0B b_l0_tile, TileL0C c_l0_tile) { - TMOV(a_l0_tile, I_neg_l1_tile); // a_l0 initialized with I_neg - TMOV(b_l0_tile, I_neg_l1_tile); // b_l0 initialized with I_neg - set_flag(PIPE_MTE1, PIPE_M, static_cast(0)); - wait_flag(PIPE_MTE1, PIPE_M, static_cast(0)); - - TMATMUL(c_l0_tile, a_l0_tile, b_l0_tile); // c_l0 contains I - set_flag(PIPE_M, PIPE_FIX, static_cast(0)); - wait_flag(PIPE_M, PIPE_FIX, static_cast(0)); - - TMOV(I_l1_tile, c_l0_tile); // I_l1 now contains I - set_flag(PIPE_FIX, PIPE_MTE1, static_cast(0)); - wait_flag(PIPE_FIX, PIPE_MTE1, static_cast(0)); - - TMOV(b_l0_tile, I_l1_tile); // b_l0 contains I - set_flag(PIPE_MTE1, PIPE_M, static_cast(0)); - wait_flag(PIPE_MTE1, PIPE_M, static_cast(0)); - - TMATMUL_ACC(c_l0_tile, c_l0_tile, a_l0_tile, - b_l0_tile); // c_l0 contains zeros - set_flag(PIPE_M, PIPE_FIX, static_cast(0)); - wait_flag(PIPE_M, PIPE_FIX, static_cast(0)); - - TMOV(Zero_l1_tile, c_l0_tile); // Zeros_l1 now contains zeros - set_flag(PIPE_FIX, PIPE_MTE1, static_cast(0)); - wait_flag(PIPE_FIX, PIPE_MTE1, static_cast(0)); + TileL1AB I_neg_l1_tile, TileL1AB Zero_l1_tile, TileL1AB I_l1_tile, TileL0A a_l0_tile, TileL0B b_l0_tile, + TileL0C c_l0_tile +) { + TMOV(a_l0_tile, I_neg_l1_tile); // a_l0 initialized with I_neg + TMOV(b_l0_tile, I_neg_l1_tile); // b_l0 initialized with I_neg + set_flag(PIPE_MTE1, PIPE_M, static_cast(0)); + wait_flag(PIPE_MTE1, PIPE_M, static_cast(0)); + + TMATMUL(c_l0_tile, a_l0_tile, b_l0_tile); // c_l0 contains I + set_flag(PIPE_M, PIPE_FIX, static_cast(0)); + wait_flag(PIPE_M, PIPE_FIX, static_cast(0)); + + TMOV(I_l1_tile, c_l0_tile); // I_l1 now contains I + set_flag(PIPE_FIX, PIPE_MTE1, static_cast(0)); + wait_flag(PIPE_FIX, PIPE_MTE1, static_cast(0)); + + TMOV(b_l0_tile, I_l1_tile); // b_l0 contains I + set_flag(PIPE_MTE1, PIPE_M, static_cast(0)); + 
wait_flag(PIPE_MTE1, PIPE_M, static_cast(0)); + + TMATMUL_ACC(c_l0_tile, c_l0_tile, a_l0_tile, + b_l0_tile); // c_l0 contains zeros + set_flag(PIPE_M, PIPE_FIX, static_cast(0)); + wait_flag(PIPE_M, PIPE_FIX, static_cast(0)); + + TMOV(Zero_l1_tile, c_l0_tile); // Zeros_l1 now contains zeros + set_flag(PIPE_FIX, PIPE_MTE1, static_cast(0)); + wait_flag(PIPE_FIX, PIPE_MTE1, static_cast(0)); } /** @@ -290,242 +270,239 @@ AICORE inline void PrepareAuxiliaryMatrices( * unrolled recursion part of the algorithm, where we need to copy alternating * blocks of X in each iteration. */ -template -AICORE inline void InvertSingleTile(TileL1AB X_l1_tile, TileL1AB I_l1_tile, - TileL1AB I_neg_l1_tile, - TileL1AB M_neg_l1_tile, - TileL1AB Zero_l1_tile, TileL1AB Y_l1_tile, - TileL0A* a_l0_tile, TileL0B* b_l0_tile, - TileL0C* c_l0_tile, const uint32_t tile_id, - const bool swap_parity = false) { - const event_t event_0 = static_cast(tile_id); - const event_t event_1 = static_cast(tile_id + NumTilesPerCubeIter); - - TMOV(b_l0_tile[0], Y_l1_tile); // b_l0[0] contains M - TMOV(a_l0_tile[0], I_neg_l1_tile); // a_l0[0] contains I_neg - set_flag(PIPE_MTE1, PIPE_M, event_0); - TMOV(a_l0_tile[1], Zero_l1_tile); - TMOV(b_l0_tile[1], Zero_l1_tile); - set_flag(PIPE_MTE1, PIPE_M, event_1); - wait_flag(PIPE_MTE1, PIPE_M, event_1); - set_flag(PIPE_M, PIPE_MTE1, event_1); - wait_flag(PIPE_M, PIPE_MTE1, event_1); - CopyDiagonalFractalsL1ToL0( - Y_l1_tile, a_l0_tile[1]); // a_l0[1] = diag_fractals(M) - CopyDiagonalFractalsL1ToL0( - Y_l1_tile, b_l0_tile[1]); // b_l0[1] = diag_fractals(M) - set_flag(PIPE_MTE1, PIPE_M, event_1); - - /* First Matmul: event_0 */ - wait_flag(PIPE_MTE1, PIPE_M, event_0); - TMATMUL(c_l0_tile[0], a_l0_tile[0], b_l0_tile[0]); // c_l0[0] contains M_neg - set_flag(PIPE_M, PIPE_FIX, event_0); - set_flag(PIPE_M, PIPE_MTE1, event_0); - - wait_flag(PIPE_M, PIPE_FIX, event_0); - TMOV(M_neg_l1_tile, c_l0_tile[0]); // M_neg_l1 now contains M_neg - set_flag(PIPE_FIX, PIPE_M, event_0); - - 
/* Second Matmul: event_1 */ - wait_flag(PIPE_MTE1, PIPE_M, event_1); - set_flag(PIPE_MTE1, PIPE_M, event_1); - TMATMUL(c_l0_tile[1], a_l0_tile[1], - b_l0_tile[1]); // c_l0[1] contains diag_fractals(M)^2 - set_flag(PIPE_M, PIPE_FIX, event_1); - wait_flag(PIPE_M, PIPE_FIX, event_1); - TMOV(Y_l1_tile, - c_l0_tile[1]); // Y_l1 now contains diag_fractals(M)^2 - set_flag(PIPE_FIX, PIPE_M, event_1); - wait_flag(PIPE_FIX, PIPE_M, event_1); - - /* Third Matmul: event_0*/ - wait_flag(PIPE_M, PIPE_MTE1, event_0); - TMOV(b_l0_tile[0], I_neg_l1_tile); // b_l0[0] contains I_neg - TMOV(a_l0_tile[0], I_neg_l1_tile); // a_l0[0] contains I_neg - set_flag(PIPE_MTE1, PIPE_M, event_0); - - wait_flag(PIPE_MTE1, PIPE_M, event_0); - wait_flag(PIPE_FIX, PIPE_M, event_0); - wait_flag(PIPE_MTE1, PIPE_M, event_1); - TMATMUL(c_l0_tile[0], a_l0_tile[1], - b_l0_tile[0]); // c_l0[0] = diag_fractals(M_neg) - set_flag(PIPE_M, PIPE_FIX, event_0); - wait_flag(PIPE_M, PIPE_FIX, event_0); - set_flag(PIPE_FIX, PIPE_M, event_0); - wait_flag(PIPE_FIX, PIPE_M, event_0); - - TMATMUL_ACC(c_l0_tile[0], c_l0_tile[0], a_l0_tile[0], - b_l0_tile[0]); // c_l0[0] has I-diag_fractals(M) - set_flag(PIPE_M, PIPE_FIX, event_1); - wait_flag(PIPE_M, PIPE_FIX, event_1); - TMOV(X_l1_tile, c_l0_tile[0]); // X_l1 now contains I-diag_fractals(M) - - /* - * Inv Trick part: - * X = I - M - * Y = M - * block_size = 1 - * while block_size < FractalSize / 2: - * Y = Y @ Y - * X = X + X @ Y - * block_size *= 2 - */ - set_flag(PIPE_FIX, PIPE_M, event_0); // store c - set_flag(PIPE_M, PIPE_MTE1, event_0); // load matrices for matmuls - set_flag(PIPE_FIX, PIPE_MTE1, event_0); - set_flag(PIPE_FIX, PIPE_M, event_1); // only for update Y - set_flag(PIPE_M, PIPE_MTE1, event_1); // only for update Y - set_flag(PIPE_FIX, PIPE_MTE1, event_1); // only for update Y - for (uint32_t block_size = 1; block_size < FractalSize / 2; block_size *= 2) { - wait_flag(PIPE_M, PIPE_MTE1, event_0); - TMOV(b_l0_tile[0], I_l1_tile); - wait_flag(PIPE_FIX, 
PIPE_MTE1, event_0); - TMOV(a_l0_tile[0], X_l1_tile); +template < + typename InputT, typename TileL1AB, typename TileL0A, typename TileL0B, typename TileL0C, uint32_t MatrixSize, + uint32_t FractalSize, uint32_t NumTilesPerCubeIter> +AICORE inline void InvertSingleTile( + TileL1AB X_l1_tile, TileL1AB I_l1_tile, TileL1AB I_neg_l1_tile, TileL1AB M_neg_l1_tile, TileL1AB Zero_l1_tile, + TileL1AB Y_l1_tile, TileL0A *a_l0_tile, TileL0B *b_l0_tile, TileL0C *c_l0_tile, const uint32_t tile_id, + const bool swap_parity = false +) { + const event_t event_0 = static_cast(tile_id); + const event_t event_1 = static_cast(tile_id + NumTilesPerCubeIter); + + TMOV(b_l0_tile[0], Y_l1_tile); // b_l0[0] contains M + TMOV(a_l0_tile[0], I_neg_l1_tile); // a_l0[0] contains I_neg set_flag(PIPE_MTE1, PIPE_M, event_0); - - wait_flag(PIPE_FIX, PIPE_MTE1, event_1); - TMOV(b_l0_tile[1], Y_l1_tile); + TMOV(a_l0_tile[1], Zero_l1_tile); + TMOV(b_l0_tile[1], Zero_l1_tile); set_flag(PIPE_MTE1, PIPE_M, event_1); - - wait_flag(PIPE_FIX, PIPE_M, event_0); // from previous iter - wait_flag(PIPE_MTE1, PIPE_M, event_0); // from loading a_l0[0], b_l0[0] - TMATMUL(c_l0_tile[0], a_l0_tile[0], b_l0_tile[0]); // c_l0[0] contains X - set_flag(PIPE_M, PIPE_FIX, event_0); - wait_flag(PIPE_M, PIPE_FIX, event_0); - set_flag(PIPE_FIX, PIPE_M, event_0); - wait_flag(PIPE_FIX, PIPE_M, event_0); - - if (block_size < FractalSize / 4) { // Update Y except in last iteration - wait_flag(PIPE_M, PIPE_MTE1, event_1); // from previous iter - TMOV(a_l0_tile[1], Y_l1_tile); - wait_flag(PIPE_MTE1, PIPE_M, event_1); - set_flag(PIPE_MTE1, PIPE_M, event_1); - - wait_flag(PIPE_MTE1, PIPE_M, event_1); - wait_flag(PIPE_FIX, PIPE_M, event_1); // from previous iter - TMATMUL(c_l0_tile[1], a_l0_tile[1], b_l0_tile[1]); - set_flag(PIPE_M, PIPE_MTE1, event_1); // for next iter - set_flag(PIPE_M, PIPE_FIX, event_1); - set_flag(PIPE_MTE1, PIPE_M, event_1); - - wait_flag(PIPE_M, PIPE_FIX, event_1); - TMOV(Y_l1_tile, c_l0_tile[1]); - 
set_flag(PIPE_FIX, PIPE_M, event_1); // for next iter - } - set_flag(PIPE_FIX, PIPE_MTE1, event_1); // for next iter - wait_flag(PIPE_MTE1, PIPE_M, event_1); - TMATMUL_ACC(c_l0_tile[0], c_l0_tile[0], a_l0_tile[0], - b_l0_tile[1]); // c_l0[0] has X + X @ Y - set_flag(PIPE_M, PIPE_MTE1, event_0); - set_flag(PIPE_M, PIPE_FIX, event_0); - - wait_flag(PIPE_M, PIPE_FIX, event_0); - TMOV(X_l1_tile, c_l0_tile[0]); - set_flag(PIPE_FIX, PIPE_M, event_0); // for next iter - set_flag(PIPE_FIX, PIPE_MTE1, event_0); // for next iter - } - wait_flag(PIPE_FIX, PIPE_MTE1, event_1); // only for update Y - wait_flag(PIPE_M, PIPE_MTE1, event_1); // only for update Y - wait_flag(PIPE_FIX, PIPE_M, event_1); // only for update Y - wait_flag(PIPE_FIX, PIPE_MTE1, event_0); - wait_flag(PIPE_M, PIPE_MTE1, event_0); - wait_flag(PIPE_FIX, PIPE_M, event_0); - - /* - * Unrolled recursion part: - * block_size = FractalSize - * while block_size < MatrixSize: - * LX = even_blocks(X, block_size) - * RX = odd_blocks(X, block_size) - * Y = LX @ (-M) + I - * X = Y @ RX + LX - * block_size *= 2 - * - * Comments: - * Upper-tri (swap_parity=false): - * LX = even_blocks(X), RX = odd_blocks(X) - * Y = LX @ (-M) + I, X = Y @ RX + LX - * Lower-tri (swap_parity=true): - * RX = even→L0A(odd via swap), LX = odd→L0B(even via swap) - * Y = RX @ (-M) + I, X = Y @ LX + RX - */ - TMOV(b_l0_tile[1], M_neg_l1_tile); // b_l0[1] contains M_neg - TMOV(a_l0_tile[0], I_l1_tile); // a_l0[0] contains I - - if constexpr (MatrixSize > FractalSize) { - set_flag(PIPE_FIX, PIPE_M, event_1); - } - set_flag(PIPE_M, PIPE_MTE1, event_1); - set_flag(PIPE_M, PIPE_MTE1, event_0); - set_flag(PIPE_FIX, PIPE_MTE1, event_1); - set_flag(PIPE_FIX, PIPE_M, event_0); - for (uint32_t block_size = FractalSize; block_size < MatrixSize; - block_size *= 2) { - wait_flag(PIPE_M, PIPE_MTE1, event_0); // Wait for last iter a_l0[1] - TMOV(a_l0_tile[1], Zero_l1_tile); - + set_flag(PIPE_M, PIPE_MTE1, event_1); wait_flag(PIPE_M, PIPE_MTE1, event_1); - 
TMOV(b_l0_tile[0], I_l1_tile); - set_flag(PIPE_MTE1, PIPE_M, event_0); - - wait_flag(PIPE_FIX, PIPE_MTE1, event_1); // Wait to write last X - CopyOddOrEvenBlocksL1ToL0( - X_l1_tile, a_l0_tile[1], block_size, - swap_parity); // a_l0[1]: even(LX) or odd(RX) + CopyDiagonalFractalsL1ToL0(Y_l1_tile, a_l0_tile[1]); // a_l0[1] = diag_fractals(M) + CopyDiagonalFractalsL1ToL0(Y_l1_tile, b_l0_tile[1]); // b_l0[1] = diag_fractals(M) set_flag(PIPE_MTE1, PIPE_M, event_1); + /* First Matmul: event_0 */ wait_flag(PIPE_MTE1, PIPE_M, event_0); - wait_flag(PIPE_FIX, PIPE_M, event_0); // Wait c_l0[0] from previous iter - TMATMUL(c_l0_tile[0], a_l0_tile[0], b_l0_tile[0]); // c_l0[0] has I - - wait_flag(PIPE_MTE1, PIPE_M, event_1); - wait_flag(PIPE_FIX, PIPE_M, event_1); // Wait c_l0[1] from previous iter - TMATMUL(c_l0_tile[1], a_l0_tile[1], b_l0_tile[0]); // c_l0[1] contains LX - set_flag(PIPE_M, PIPE_MTE1, event_1); // allow to load RX on b_l0[0] - - TMATMUL_ACC(c_l0_tile[0], c_l0_tile[0], a_l0_tile[1], - b_l0_tile[1]); // c_l0[0] <- LX * M_neg + I + TMATMUL(c_l0_tile[0], a_l0_tile[0], b_l0_tile[0]); // c_l0[0] contains M_neg set_flag(PIPE_M, PIPE_FIX, event_0); set_flag(PIPE_M, PIPE_MTE1, event_0); wait_flag(PIPE_M, PIPE_FIX, event_0); - TMOV(Y_l1_tile, c_l0_tile[0]); // Y_l1 contains LX * M_neg + I - set_flag(PIPE_FIX, PIPE_MTE1, event_0); + TMOV(M_neg_l1_tile, c_l0_tile[0]); // M_neg_l1 now contains M_neg set_flag(PIPE_FIX, PIPE_M, event_0); - /* Load complementary blocks of X in L0B. 
If swap_parity = fase, "Load Odd - * Blocks Of X In L0B" */ - wait_flag(PIPE_M, PIPE_MTE1, event_1); - TMOV(b_l0_tile[0], Zero_l1_tile); - CopyOddOrEvenBlocksL1ToL0( - X_l1_tile, b_l0_tile[0], block_size, - swap_parity); // b_l0[0]: odd(RX) or even(LX) - - wait_flag(PIPE_M, PIPE_MTE1, event_0); // Wait for previous use of a_l0[1] - wait_flag(PIPE_FIX, PIPE_MTE1, event_0); // Wait for Y_l1 - TMOV(a_l0_tile[1], Y_l1_tile); // a_l0[1] contains LX * M_neg + I + /* Second Matmul: event_1 */ + wait_flag(PIPE_MTE1, PIPE_M, event_1); + set_flag(PIPE_MTE1, PIPE_M, event_1); + TMATMUL(c_l0_tile[1], a_l0_tile[1], + b_l0_tile[1]); // c_l0[1] contains diag_fractals(M)^2 + set_flag(PIPE_M, PIPE_FIX, event_1); + wait_flag(PIPE_M, PIPE_FIX, event_1); + TMOV(Y_l1_tile, + c_l0_tile[1]); // Y_l1 now contains diag_fractals(M)^2 + set_flag(PIPE_FIX, PIPE_M, event_1); + wait_flag(PIPE_FIX, PIPE_M, event_1); + + /* Third Matmul: event_0*/ + wait_flag(PIPE_M, PIPE_MTE1, event_0); + TMOV(b_l0_tile[0], I_neg_l1_tile); // b_l0[0] contains I_neg + TMOV(a_l0_tile[0], I_neg_l1_tile); // a_l0[0] contains I_neg set_flag(PIPE_MTE1, PIPE_M, event_0); wait_flag(PIPE_MTE1, PIPE_M, event_0); - TMATMUL_ACC(c_l0_tile[1], c_l0_tile[1], a_l0_tile[1], b_l0_tile[0]); - set_flag(PIPE_M, PIPE_MTE1, event_0); // next iter can read on a_l0[1] - set_flag(PIPE_M, PIPE_MTE1, event_1); // next iter can read on b_l0[0] + wait_flag(PIPE_FIX, PIPE_M, event_0); + wait_flag(PIPE_MTE1, PIPE_M, event_1); + TMATMUL(c_l0_tile[0], a_l0_tile[1], + b_l0_tile[0]); // c_l0[0] = diag_fractals(M_neg) set_flag(PIPE_M, PIPE_FIX, event_0); wait_flag(PIPE_M, PIPE_FIX, event_0); + set_flag(PIPE_FIX, PIPE_M, event_0); + wait_flag(PIPE_FIX, PIPE_M, event_0); - if (block_size < MatrixSize / 2) { // Update X_l1 except in last iteration - TMOV(X_l1_tile, c_l0_tile[1]); - set_flag(PIPE_FIX, PIPE_M, event_1); // release c_l0[1] for next iter + TMATMUL_ACC(c_l0_tile[0], c_l0_tile[0], a_l0_tile[0], + b_l0_tile[0]); // c_l0[0] has 
I-diag_fractals(M) + set_flag(PIPE_M, PIPE_FIX, event_1); + wait_flag(PIPE_M, PIPE_FIX, event_1); + TMOV(X_l1_tile, c_l0_tile[0]); // X_l1 now contains I-diag_fractals(M) + + /* + * Inv Trick part: + * X = I - M + * Y = M + * block_size = 1 + * while block_size < FractalSize / 2: + * Y = Y @ Y + * X = X + X @ Y + * block_size *= 2 + */ + set_flag(PIPE_FIX, PIPE_M, event_0); // store c + set_flag(PIPE_M, PIPE_MTE1, event_0); // load matrices for matmuls + set_flag(PIPE_FIX, PIPE_MTE1, event_0); + set_flag(PIPE_FIX, PIPE_M, event_1); // only for update Y + set_flag(PIPE_M, PIPE_MTE1, event_1); // only for update Y + set_flag(PIPE_FIX, PIPE_MTE1, event_1); // only for update Y + for (uint32_t block_size = 1; block_size < FractalSize / 2; block_size *= 2) { + wait_flag(PIPE_M, PIPE_MTE1, event_0); + TMOV(b_l0_tile[0], I_l1_tile); + wait_flag(PIPE_FIX, PIPE_MTE1, event_0); + TMOV(a_l0_tile[0], X_l1_tile); + set_flag(PIPE_MTE1, PIPE_M, event_0); + + wait_flag(PIPE_FIX, PIPE_MTE1, event_1); + TMOV(b_l0_tile[1], Y_l1_tile); + set_flag(PIPE_MTE1, PIPE_M, event_1); + + wait_flag(PIPE_FIX, PIPE_M, event_0); // from previous iter + wait_flag(PIPE_MTE1, PIPE_M, event_0); // from loading a_l0[0], b_l0[0] + TMATMUL(c_l0_tile[0], a_l0_tile[0], b_l0_tile[0]); // c_l0[0] contains X + set_flag(PIPE_M, PIPE_FIX, event_0); + wait_flag(PIPE_M, PIPE_FIX, event_0); + set_flag(PIPE_FIX, PIPE_M, event_0); + wait_flag(PIPE_FIX, PIPE_M, event_0); + + if (block_size < FractalSize / 4) { // Update Y except in last iteration + wait_flag(PIPE_M, PIPE_MTE1, event_1); // from previous iter + TMOV(a_l0_tile[1], Y_l1_tile); + wait_flag(PIPE_MTE1, PIPE_M, event_1); + set_flag(PIPE_MTE1, PIPE_M, event_1); + + wait_flag(PIPE_MTE1, PIPE_M, event_1); + wait_flag(PIPE_FIX, PIPE_M, event_1); // from previous iter + TMATMUL(c_l0_tile[1], a_l0_tile[1], b_l0_tile[1]); + set_flag(PIPE_M, PIPE_MTE1, event_1); // for next iter + set_flag(PIPE_M, PIPE_FIX, event_1); + set_flag(PIPE_MTE1, PIPE_M, event_1); + + 
wait_flag(PIPE_M, PIPE_FIX, event_1); + TMOV(Y_l1_tile, c_l0_tile[1]); + set_flag(PIPE_FIX, PIPE_M, event_1); // for next iter + } + set_flag(PIPE_FIX, PIPE_MTE1, event_1); // for next iter + + wait_flag(PIPE_MTE1, PIPE_M, event_1); + TMATMUL_ACC(c_l0_tile[0], c_l0_tile[0], a_l0_tile[0], + b_l0_tile[1]); // c_l0[0] has X + X @ Y + set_flag(PIPE_M, PIPE_MTE1, event_0); + set_flag(PIPE_M, PIPE_FIX, event_0); + + wait_flag(PIPE_M, PIPE_FIX, event_0); + TMOV(X_l1_tile, c_l0_tile[0]); + set_flag(PIPE_FIX, PIPE_M, event_0); // for next iter + set_flag(PIPE_FIX, PIPE_MTE1, event_0); // for next iter } + wait_flag(PIPE_FIX, PIPE_MTE1, event_1); // only for update Y + wait_flag(PIPE_M, PIPE_MTE1, event_1); // only for update Y + wait_flag(PIPE_FIX, PIPE_M, event_1); // only for update Y + wait_flag(PIPE_FIX, PIPE_MTE1, event_0); + wait_flag(PIPE_M, PIPE_MTE1, event_0); + wait_flag(PIPE_FIX, PIPE_M, event_0); + + /* + * Unrolled recursion part: + * block_size = FractalSize + * while block_size < MatrixSize: + * LX = even_blocks(X, block_size) + * RX = odd_blocks(X, block_size) + * Y = LX @ (-M) + I + * X = Y @ RX + LX + * block_size *= 2 + * + * Comments: + * Upper-tri (swap_parity=false): + * LX = even_blocks(X), RX = odd_blocks(X) + * Y = LX @ (-M) + I, X = Y @ RX + LX + * Lower-tri (swap_parity=true): + * RX = even→L0A(odd via swap), LX = odd→L0B(even via swap) + * Y = RX @ (-M) + I, X = Y @ LX + RX + */ + TMOV(b_l0_tile[1], M_neg_l1_tile); // b_l0[1] contains M_neg + TMOV(a_l0_tile[0], I_l1_tile); // a_l0[0] contains I + + if constexpr (MatrixSize > FractalSize) { + set_flag(PIPE_FIX, PIPE_M, event_1); + } + set_flag(PIPE_M, PIPE_MTE1, event_1); + set_flag(PIPE_M, PIPE_MTE1, event_0); set_flag(PIPE_FIX, PIPE_MTE1, event_1); - } - wait_flag(PIPE_M, PIPE_MTE1, event_0); - wait_flag(PIPE_M, PIPE_MTE1, event_1); - wait_flag(PIPE_FIX, PIPE_M, event_0); - wait_flag(PIPE_FIX, PIPE_MTE1, event_1); // Write c_l0[1] to X_l1 + set_flag(PIPE_FIX, PIPE_M, event_0); + for (uint32_t 
block_size = FractalSize; block_size < MatrixSize; block_size *= 2) { + wait_flag(PIPE_M, PIPE_MTE1, event_0); // Wait for last iter a_l0[1] + TMOV(a_l0_tile[1], Zero_l1_tile); + + wait_flag(PIPE_M, PIPE_MTE1, event_1); + TMOV(b_l0_tile[0], I_l1_tile); + set_flag(PIPE_MTE1, PIPE_M, event_0); + + wait_flag(PIPE_FIX, PIPE_MTE1, event_1); // Wait to write last X + CopyOddOrEvenBlocksL1ToL0( + X_l1_tile, a_l0_tile[1], block_size, + swap_parity + ); // a_l0[1]: even(LX) or odd(RX) + set_flag(PIPE_MTE1, PIPE_M, event_1); + + wait_flag(PIPE_MTE1, PIPE_M, event_0); + wait_flag(PIPE_FIX, PIPE_M, event_0); // Wait c_l0[0] from previous iter + TMATMUL(c_l0_tile[0], a_l0_tile[0], b_l0_tile[0]); // c_l0[0] has I + + wait_flag(PIPE_MTE1, PIPE_M, event_1); + wait_flag(PIPE_FIX, PIPE_M, event_1); // Wait c_l0[1] from previous iter + TMATMUL(c_l0_tile[1], a_l0_tile[1], b_l0_tile[0]); // c_l0[1] contains LX + set_flag(PIPE_M, PIPE_MTE1, event_1); // allow to load RX on b_l0[0] + + TMATMUL_ACC(c_l0_tile[0], c_l0_tile[0], a_l0_tile[1], + b_l0_tile[1]); // c_l0[0] <- LX * M_neg + I + set_flag(PIPE_M, PIPE_FIX, event_0); + set_flag(PIPE_M, PIPE_MTE1, event_0); + + wait_flag(PIPE_M, PIPE_FIX, event_0); + TMOV(Y_l1_tile, c_l0_tile[0]); // Y_l1 contains LX * M_neg + I + set_flag(PIPE_FIX, PIPE_MTE1, event_0); + set_flag(PIPE_FIX, PIPE_M, event_0); + + /* Load complementary blocks of X in L0B. 
If swap_parity = fase, "Load Odd + * Blocks Of X In L0B" */ + wait_flag(PIPE_M, PIPE_MTE1, event_1); + TMOV(b_l0_tile[0], Zero_l1_tile); + CopyOddOrEvenBlocksL1ToL0( + X_l1_tile, b_l0_tile[0], block_size, + swap_parity + ); // b_l0[0]: odd(RX) or even(LX) + + wait_flag(PIPE_M, PIPE_MTE1, event_0); // Wait for previous use of a_l0[1] + wait_flag(PIPE_FIX, PIPE_MTE1, event_0); // Wait for Y_l1 + TMOV(a_l0_tile[1], Y_l1_tile); // a_l0[1] contains LX * M_neg + I + set_flag(PIPE_MTE1, PIPE_M, event_0); + + wait_flag(PIPE_MTE1, PIPE_M, event_0); + TMATMUL_ACC(c_l0_tile[1], c_l0_tile[1], a_l0_tile[1], b_l0_tile[0]); + set_flag(PIPE_M, PIPE_MTE1, event_0); // next iter can read on a_l0[1] + set_flag(PIPE_M, PIPE_MTE1, event_1); // next iter can read on b_l0[0] + set_flag(PIPE_M, PIPE_FIX, event_0); + wait_flag(PIPE_M, PIPE_FIX, event_0); + + if (block_size < MatrixSize / 2) { // Update X_l1 except in last iteration + TMOV(X_l1_tile, c_l0_tile[1]); + set_flag(PIPE_FIX, PIPE_M, event_1); // release c_l0[1] for next iter + } + set_flag(PIPE_FIX, PIPE_MTE1, event_1); + } + wait_flag(PIPE_M, PIPE_MTE1, event_0); + wait_flag(PIPE_M, PIPE_MTE1, event_1); + wait_flag(PIPE_FIX, PIPE_M, event_0); + wait_flag(PIPE_FIX, PIPE_MTE1, event_1); // Write c_l0[1] to X_l1 } /** @@ -547,359 +524,316 @@ AICORE inline void InvertSingleTile(TileL1AB X_l1_tile, TileL1AB I_l1_tile, * @param M_inv pointer to the global memory to store the final inverse. * @param M Pointer to the global tensor matrix in global memory. * @param I_neg Pointer to global memory that contains the negative identity. - * @param total_tiles The total number of matrices to invert. + * @param num_matrices The total number of matrices to invert. * @param num_bsnd_heads The number of heads, only for BSND format. * @param is_lower If input matrices are lower-triangular (is_lower == 1) or * upper-triangular (is_lower == 0). Default is upper triangular. * @param num_bsnd_heads The number of heads, only for BSND format. 
*/ -template -AICORE inline void TriInvRecUnrollKernel(__gm__ OutputT* M_inv, - __gm__ InputT* M, __gm__ InputT* I_neg, - uint32_t total_tiles, - uint32_t num_bsnd_heads = 0, - uint32_t is_lower = 0, - __gm__ int32_t* cu_seqlens = nullptr) { - /* Initializations */ - constexpr uint32_t TileLen = MatrixSize * MatrixSize; - constexpr uint32_t FractalSize = 16; // fractal size for half /bf16 - constexpr uint32_t NumFractalsRowWise = MatrixSize / FractalSize; - constexpr uint32_t NumL0Buffers = 2; - - if (get_block_idx() * NumTilesPerCubeIter >= total_tiles) { - return; - } - - using GlobalTileShapeIn = - TileShape2D; - using GlobalTileStridesIn = typename std::conditional< - !IsBSND, BaseShape2D, - Stride<1, 1, 1, -1, 1>>::type; - using GlobalTileIn = - GlobalTensor; - using GlobalTileDynamicShape = Shape<1, 1, 1, DYNAMIC, DYNAMIC>; - using GlobalTileDynamicStride = Stride<1, 1, 1, DYNAMIC, 1>; - using GlobalTileDynamicIn = GlobalTensor; - using GlobalTileStridesINeg = - BaseShape2D; - using GlobalTileINeg = GlobalTensor; - - using GlobalTileShapeOut = - TileShape2D; - using GlobalTileStridesOut = typename std::conditional< - !IsBSND, BaseShape2D, - Stride<1, 1, 1, -1, 1>>::type; - using GlobalTileOut = GlobalTensor; - using GlobalTileDynamicOut = - GlobalTensor; - using TileL1AB = - Tile; - using TileL1ABDynamic = - Tile; - - // L0 Memory - using TileL0A = TileLeft; - using TileL0B = TileRight; - using TileL0C = TileAcc; - using TileL0CDynamic = - TileAcc; - - GlobalTileINeg I_neg_global_in(I_neg); - - TileL1AB X_l1_tile; - TileL1AB I_l1_tile; - TileL1AB I_neg_l1_tile; - TileL1AB M_neg_l1_tile; - TileL1AB Zero_l1_tile; - TileL1AB Y_l1_tile[NumTilesPerCubeIter]; - - TileL0A a_l0_tile[NumL0Buffers]; - TileL0B b_l0_tile[NumL0Buffers]; - TileL0C c_l0_tile[NumL0Buffers]; - - TASSIGN(I_l1_tile, 0x0); - TASSIGN(I_neg_l1_tile, 0x0 + TileLen * sizeof(InputT)); - TASSIGN(Zero_l1_tile, 0x0 + 2 * TileLen * sizeof(InputT)); - TASSIGN(M_neg_l1_tile, 0x0 + 3 * TileLen * 
sizeof(InputT)); - TASSIGN(X_l1_tile, 0x0 + 4 * TileLen * sizeof(InputT)); - for (uint32_t tile_id = 0; tile_id < NumTilesPerCubeIter; ++tile_id) { - TASSIGN(Y_l1_tile[tile_id], 0x0 + (5 + tile_id) * TileLen * sizeof(InputT)); - } - - for (uint32_t buffer_num = 0; buffer_num < NumL0Buffers; ++buffer_num) { - TASSIGN(a_l0_tile[buffer_num], 0x0 + buffer_num * TileLen * sizeof(InputT)); - TASSIGN(b_l0_tile[buffer_num], 0x0 + buffer_num * TileLen * sizeof(InputT)); - TASSIGN(c_l0_tile[buffer_num], 0x0 + buffer_num * TileLen * sizeof(float)); - } - TLOAD(I_neg_l1_tile, I_neg_global_in); - set_flag(PIPE_MTE2, PIPE_MTE1, static_cast(0)); - wait_flag(PIPE_MTE2, PIPE_MTE1, static_cast(0)); - - PrepareAuxiliaryMatrices( - I_neg_l1_tile, Zero_l1_tile, I_l1_tile, a_l0_tile[0], b_l0_tile[0], - c_l0_tile[0]); - - const uint32_t max_iters_per_aic = - CeilDiv(total_tiles, (uint32_t)(NumTilesPerCubeIter * get_block_num())); - - /* Main iteration - Compute all tiles */ - uint32_t bsnd_tile_offsets[NumTilesPerCubeIter] = {0}; - uint32_t bsnd_tile_valid_sizes[NumTilesPerCubeIter] = {0}; - uint32_t next_tile_id_that_waits_for_pipe_fix_pipe_m = 0; - set_flag(PIPE_FIX, PIPE_M, - static_cast(next_tile_id_that_waits_for_pipe_fix_pipe_m)); - for (uint32_t tile_id = 0; tile_id < NumTilesPerCubeIter; ++tile_id) { - set_flag(PIPE_M, PIPE_MTE2, static_cast(tile_id)); - } - for (uint32_t cube_iter = 0; cube_iter < max_iters_per_aic; ++cube_iter) { - const uint32_t global_index = - (cube_iter * get_block_num() + get_block_idx()) * NumTilesPerCubeIter; - if (global_index >= total_tiles) { - break; +template +AICORE inline void TriInvRecUnrollKernel( + __gm__ OutputT *M_inv, __gm__ InputT *M, __gm__ InputT *I_neg, uint32_t block_dim, uint32_t num_matrices, + uint32_t num_bsnd_heads = 0, uint32_t is_lower = 0, __gm__ int32_t *cu_seqlens = nullptr +) { + /* Initializations */ + constexpr uint32_t TileLen = MatrixSize * MatrixSize; + constexpr uint32_t FractalSize = 16; // fractal size for half /bf16 
+ constexpr uint32_t NumFractalsRowWise = MatrixSize / FractalSize; + constexpr uint32_t NumL0Buffers = 2; + + if (get_block_idx() * NumTilesPerCubeIter >= num_matrices) { + return; } - for (uint32_t tile_id = 0; (tile_id < NumTilesPerCubeIter) && - (global_index + tile_id < total_tiles); - ++tile_id) { - if constexpr (IsBSND) { - const uint32_t global_tile_id = global_index + tile_id; - if (cu_seqlens != nullptr) { - const BSNDVarlenTileInfo tile_info = - kernel_utils::GetBSNDVarlenTileInfoFromCuSeqlens( - global_tile_id, num_bsnd_heads, MatrixSize, cu_seqlens); - bsnd_tile_offsets[tile_id] = tile_info.bsnd_offset; - bsnd_tile_valid_sizes[tile_id] = tile_info.valid_size; - } else { - bsnd_tile_offsets[tile_id] = kernel_utils::GetBSNDFixedTileOffset( - global_tile_id, num_bsnd_heads, MatrixSize); - bsnd_tile_valid_sizes[tile_id] = MatrixSize; + + using GlobalTileShapeIn = TileShape2D; + using GlobalTileStridesIn = typename std::conditional< + !IsBSND, BaseShape2D, Stride<1, 1, 1, -1, 1>>::type; + using GlobalTileIn = GlobalTensor; + using GlobalTileDynamicShape = Shape<1, 1, 1, DYNAMIC, DYNAMIC>; + using GlobalTileDynamicStride = Stride<1, 1, 1, DYNAMIC, 1>; + using GlobalTileDynamicIn = GlobalTensor; + using GlobalTileStridesINeg = BaseShape2D; + using GlobalTileINeg = GlobalTensor; + + using GlobalTileShapeOut = TileShape2D; + using GlobalTileStridesOut = typename std::conditional< + !IsBSND, BaseShape2D, Stride<1, 1, 1, -1, 1>>::type; + using GlobalTileOut = GlobalTensor; + using GlobalTileDynamicOut = GlobalTensor; + using TileL1AB = Tile< + TileType::Mat, InputT, MatrixSize, MatrixSize, BLayout::ColMajor, MatrixSize, MatrixSize, SLayout::RowMajor, + 512>; + using TileL1ABDynamic = Tile< + TileType::Mat, InputT, MatrixSize, MatrixSize, BLayout::ColMajor, DYNAMIC, DYNAMIC, SLayout::RowMajor, 512, + PadValue::Zero>; + + // L0 Memory + using TileL0A = TileLeft; + using TileL0B = TileRight; + using TileL0C = TileAcc; + using TileL0CDynamic = TileAcc; + + 
GlobalTileINeg I_neg_global_in(I_neg); + + TileL1AB X_l1_tile; + TileL1AB I_l1_tile; + TileL1AB I_neg_l1_tile; + TileL1AB M_neg_l1_tile; + TileL1AB Zero_l1_tile; + TileL1AB Y_l1_tile[NumTilesPerCubeIter]; + + TileL0A a_l0_tile[NumL0Buffers]; + TileL0B b_l0_tile[NumL0Buffers]; + TileL0C c_l0_tile[NumL0Buffers]; + + TASSIGN(I_l1_tile, 0x0); + TASSIGN(I_neg_l1_tile, 0x0 + TileLen * sizeof(InputT)); + TASSIGN(Zero_l1_tile, 0x0 + 2 * TileLen * sizeof(InputT)); + TASSIGN(M_neg_l1_tile, 0x0 + 3 * TileLen * sizeof(InputT)); + TASSIGN(X_l1_tile, 0x0 + 4 * TileLen * sizeof(InputT)); + for (uint32_t tile_id = 0; tile_id < NumTilesPerCubeIter; ++tile_id) { + TASSIGN(Y_l1_tile[tile_id], 0x0 + (5 + tile_id) * TileLen * sizeof(InputT)); + } + + for (uint32_t buffer_num = 0; buffer_num < NumL0Buffers; ++buffer_num) { + TASSIGN(a_l0_tile[buffer_num], 0x0 + buffer_num * TileLen * sizeof(InputT)); + TASSIGN(b_l0_tile[buffer_num], 0x0 + buffer_num * TileLen * sizeof(InputT)); + TASSIGN(c_l0_tile[buffer_num], 0x0 + buffer_num * TileLen * sizeof(float)); + } + TLOAD(I_neg_l1_tile, I_neg_global_in); + set_flag(PIPE_MTE2, PIPE_MTE1, static_cast(0)); + wait_flag(PIPE_MTE2, PIPE_MTE1, static_cast(0)); + + PrepareAuxiliaryMatrices( + I_neg_l1_tile, Zero_l1_tile, I_l1_tile, a_l0_tile[0], b_l0_tile[0], c_l0_tile[0] + ); + + const uint32_t max_iters_per_aic = CeilDiv(num_matrices, (uint32_t)(NumTilesPerCubeIter * block_dim)); + + /* Main iteration - Compute all tiles */ + uint32_t bsnd_tile_offsets[NumTilesPerCubeIter] = {0}; + uint32_t bsnd_tile_valid_sizes[NumTilesPerCubeIter] = {0}; + uint32_t next_tile_id_that_waits_for_pipe_fix_pipe_m = 0; + set_flag(PIPE_FIX, PIPE_M, static_cast(next_tile_id_that_waits_for_pipe_fix_pipe_m)); + for (uint32_t tile_id = 0; tile_id < NumTilesPerCubeIter; ++tile_id) { + set_flag(PIPE_M, PIPE_MTE2, static_cast(tile_id)); + } + for (uint32_t cube_iter = 0; cube_iter < max_iters_per_aic; ++cube_iter) { + const uint32_t global_index = (cube_iter * block_dim + 
get_block_idx()) * NumTilesPerCubeIter; + if (global_index >= num_matrices) { + break; } - const uint32_t bsnd_offset = bsnd_tile_offsets[tile_id]; - const uint32_t valid_size = bsnd_tile_valid_sizes[tile_id]; - const int row_stride = static_cast(MatrixSize * num_bsnd_heads); - wait_flag(PIPE_M, PIPE_MTE2, static_cast(tile_id)); - if (valid_size < MatrixSize) { - TileL1ABDynamic Y_dyn_l1_tile(valid_size, valid_size); - TASSIGN(Y_dyn_l1_tile, - 0x0 + (5 + tile_id) * TileLen * sizeof(InputT)); - GlobalTileDynamicIn M_global_in_dyn( - M + bsnd_offset, - {1, 1, 1, static_cast(valid_size), - static_cast(valid_size)}, - {1, 1, 1, row_stride, 1}); - TLOAD(Y_dyn_l1_tile, M_global_in_dyn); - set_flag(PIPE_MTE2, PIPE_MTE1, static_cast(tile_id)); - wait_flag(PIPE_MTE2, PIPE_MTE1, static_cast(tile_id)); - TFILLPAD(Y_dyn_l1_tile, Y_dyn_l1_tile); - } else { - GlobalTileIn M_global_in(M + bsnd_offset, {}, {row_stride}); - TLOAD(Y_l1_tile[tile_id], M_global_in); + for (uint32_t tile_id = 0; (tile_id < NumTilesPerCubeIter) && (global_index + tile_id < num_matrices); + ++tile_id) { + if constexpr (IsBSND) { + const uint32_t global_tile_id = global_index + tile_id; + if (cu_seqlens != nullptr) { + const BSNDVarlenTileInfo tile_info = tri_inv_utils::GetBSNDVarlenTileInfoFromCuSeqlens( + global_tile_id, num_bsnd_heads, MatrixSize, cu_seqlens + ); + bsnd_tile_offsets[tile_id] = tile_info.bsnd_offset; + bsnd_tile_valid_sizes[tile_id] = tile_info.valid_size; + } else { + bsnd_tile_offsets[tile_id] = + tri_inv_utils::GetBSNDFixedTileOffset(global_tile_id, num_bsnd_heads, MatrixSize); + bsnd_tile_valid_sizes[tile_id] = MatrixSize; + } + const uint32_t bsnd_offset = bsnd_tile_offsets[tile_id]; + const uint32_t valid_size = bsnd_tile_valid_sizes[tile_id]; + const int row_stride = static_cast(MatrixSize * num_bsnd_heads); + wait_flag(PIPE_M, PIPE_MTE2, static_cast(tile_id)); + if (valid_size < MatrixSize) { + TileL1ABDynamic Y_dyn_l1_tile(valid_size, valid_size); + TASSIGN(Y_dyn_l1_tile, 0x0 + 
(5 + tile_id) * TileLen * sizeof(InputT)); + GlobalTileDynamicIn M_global_in_dyn( + M + bsnd_offset, {1, 1, 1, static_cast(valid_size), static_cast(valid_size)}, + {1, 1, 1, row_stride, 1} + ); + TLOAD(Y_dyn_l1_tile, M_global_in_dyn); + set_flag(PIPE_MTE2, PIPE_MTE1, static_cast(tile_id)); + wait_flag(PIPE_MTE2, PIPE_MTE1, static_cast(tile_id)); + TFILLPAD(Y_dyn_l1_tile, Y_dyn_l1_tile); + } else { + GlobalTileIn M_global_in(M + bsnd_offset, {}, {row_stride}); + TLOAD(Y_l1_tile[tile_id], M_global_in); + } + } else { + GlobalTileIn M_global_in(M + (global_index + tile_id) * TileLen); + wait_flag(PIPE_M, PIPE_MTE2, static_cast(tile_id)); + TLOAD(Y_l1_tile[tile_id], + M_global_in); // Copies NumTilesPerCubeIter tiles at once + } + set_flag(PIPE_MTE2, PIPE_MTE1, static_cast(tile_id)); } - } else { - GlobalTileIn M_global_in(M + (global_index + tile_id) * TileLen); - wait_flag(PIPE_M, PIPE_MTE2, static_cast(tile_id)); - TLOAD(Y_l1_tile[tile_id], - M_global_in); // Copies NumTilesPerCubeIter tiles at once - } - set_flag(PIPE_MTE2, PIPE_MTE1, static_cast(tile_id)); - } - constexpr uint32_t final_c_buffer_index = MatrixSize > FractalSize ? 
1 : 0; - for (uint32_t tile_id = 0; (tile_id < NumTilesPerCubeIter) && - (global_index + tile_id < total_tiles); - ++tile_id) { - // Wait for previous cube iter to write result - wait_flag(PIPE_FIX, PIPE_M, static_cast(tile_id)); - // Wait for loading new matrices from GM - wait_flag(PIPE_MTE2, PIPE_MTE1, static_cast(tile_id)); - - InvertSingleTile( - X_l1_tile, I_l1_tile, I_neg_l1_tile, M_neg_l1_tile, Zero_l1_tile, - Y_l1_tile[tile_id], a_l0_tile, b_l0_tile, c_l0_tile, tile_id, - is_lower != 0); - - // Allow next cube_iter to proceed for this tile_id - set_flag(PIPE_M, PIPE_MTE2, static_cast(tile_id)); - - /* Store result */ - if constexpr (IsBSND) { - const uint32_t bsnd_offset = bsnd_tile_offsets[tile_id]; - const uint32_t valid_size = bsnd_tile_valid_sizes[tile_id]; - const int row_stride = static_cast(MatrixSize * num_bsnd_heads); - if (valid_size < MatrixSize) { - TileL0CDynamic c_l0_tail_tile(valid_size, valid_size); - TASSIGN(c_l0_tail_tile, - 0x0 + final_c_buffer_index * TileLen * sizeof(float)); - GlobalTileDynamicOut M_inv_global_out_dyn( - M_inv + bsnd_offset, - {1, 1, 1, static_cast(valid_size), - static_cast(valid_size)}, - {1, 1, 1, row_stride, 1}); - TSTORE(M_inv_global_out_dyn, c_l0_tail_tile); - } else { - GlobalTileOut M_inv_global_out(M_inv + bsnd_offset, {}, {row_stride}); - TSTORE(M_inv_global_out, c_l0_tile[final_c_buffer_index]); + constexpr uint32_t final_c_buffer_index = MatrixSize > FractalSize ? 
1 : 0; + for (uint32_t tile_id = 0; (tile_id < NumTilesPerCubeIter) && (global_index + tile_id < num_matrices); + ++tile_id) { + // Wait for previous cube iter to write result + wait_flag(PIPE_FIX, PIPE_M, static_cast(tile_id)); + // Wait for loading new matrices from GM + wait_flag(PIPE_MTE2, PIPE_MTE1, static_cast(tile_id)); + + InvertSingleTile( + X_l1_tile, I_l1_tile, I_neg_l1_tile, M_neg_l1_tile, Zero_l1_tile, Y_l1_tile[tile_id], a_l0_tile, + b_l0_tile, c_l0_tile, tile_id, is_lower != 0 + ); + + // Allow next cube_iter to proceed for this tile_id + set_flag(PIPE_M, PIPE_MTE2, static_cast(tile_id)); + + /* Store result */ + if constexpr (IsBSND) { + const uint32_t bsnd_offset = bsnd_tile_offsets[tile_id]; + const uint32_t valid_size = bsnd_tile_valid_sizes[tile_id]; + const int row_stride = static_cast(MatrixSize * num_bsnd_heads); + if (valid_size < MatrixSize) { + TileL0CDynamic c_l0_tail_tile(valid_size, valid_size); + TASSIGN(c_l0_tail_tile, 0x0 + final_c_buffer_index * TileLen * sizeof(float)); + GlobalTileDynamicOut M_inv_global_out_dyn( + M_inv + bsnd_offset, {1, 1, 1, static_cast(valid_size), static_cast(valid_size)}, + {1, 1, 1, row_stride, 1} + ); + TSTORE(M_inv_global_out_dyn, c_l0_tail_tile); + } else { + GlobalTileOut M_inv_global_out(M_inv + bsnd_offset, {}, {row_stride}); + TSTORE(M_inv_global_out, c_l0_tile[final_c_buffer_index]); + } + } else { + GlobalTileOut M_inv_global_out(M_inv + (global_index + tile_id) * TileLen); + TSTORE(M_inv_global_out, c_l0_tile[final_c_buffer_index]); + } + next_tile_id_that_waits_for_pipe_fix_pipe_m = (tile_id + 1) % NumTilesPerCubeIter; + set_flag(PIPE_FIX, PIPE_M, static_cast(next_tile_id_that_waits_for_pipe_fix_pipe_m)); } - } else { - GlobalTileOut M_inv_global_out(M_inv + - (global_index + tile_id) * TileLen); - TSTORE(M_inv_global_out, c_l0_tile[final_c_buffer_index]); - } - next_tile_id_that_waits_for_pipe_fix_pipe_m = - (tile_id + 1) % NumTilesPerCubeIter; - set_flag( - PIPE_FIX, PIPE_M, - 
static_cast(next_tile_id_that_waits_for_pipe_fix_pipe_m)); } - } - for (uint32_t tile_id = 0; tile_id < NumTilesPerCubeIter; ++tile_id) { - wait_flag(PIPE_M, PIPE_MTE2, static_cast(tile_id)); - } - wait_flag(PIPE_FIX, PIPE_M, - static_cast(next_tile_id_that_waits_for_pipe_fix_pipe_m)); + for (uint32_t tile_id = 0; tile_id < NumTilesPerCubeIter; ++tile_id) { + wait_flag(PIPE_M, PIPE_MTE2, static_cast(tile_id)); + } + wait_flag(PIPE_FIX, PIPE_M, static_cast(next_tile_id_that_waits_for_pipe_fix_pipe_m)); } /* * @brief: Computes the inverses of the blocks of tensor M */ -template -AICORE void runKernelTriInvRecUnroll(__gm__ OutputT* M_inv, __gm__ InputT* M, - __gm__ InputT* I_neg, uint32_t total_tiles, - uint32_t num_bsnd_heads = 0, - uint32_t is_lower = 0, - __gm__ int32_t* cu_seqlens = nullptr) { +template +AICORE void runKernelTriInvRecUnroll( + __gm__ OutputT *M_inv, __gm__ InputT *M, __gm__ InputT *I_neg, uint32_t block_dim, uint32_t num_matrices, + uint32_t num_bsnd_heads = 0, uint32_t is_lower = 0, __gm__ int32_t *cu_seqlens = nullptr +) { #if defined(__DAV_CUBE__) // Cube compilation - TriInvRecUnrollKernel(M_inv, M, I_neg, total_tiles, num_bsnd_heads, - is_lower, cu_seqlens); + TriInvRecUnrollKernel( + M_inv, M, I_neg, block_dim, num_matrices, num_bsnd_heads, is_lower, cu_seqlens + ); #else // Nothing to do on AIV #endif } -template -AICORE void run_tri_inv_rec_unroll(__gm__ OutputT* tensor_out, - __gm__ InputT* tensor_in, - __gm__ InputT* minus_eye_in, - uint32_t matrix_size, uint32_t num_matrices, - uint32_t num_bsnd_heads, - uint32_t is_lower = 0, - __gm__ int32_t* cu_seqlens = nullptr) { - static_assert( - std::is_same_v or std::is_same_v, - "tri_inv_rec_unroll supports only fp16 or bf16."); - - static_assert( - std::is_same_v or std::is_same_v, - "tri_inv_rec_unroll supports only fp16 or bf16."); - switch (matrix_size) { +template +AICORE void run_tri_inv_rec_unroll( + __gm__ OutputT *tensor_out, __gm__ InputT *tensor_in, __gm__ InputT *minus_eye_in, 
uint32_t block_dim, + uint32_t matrix_size, uint32_t num_matrices, uint32_t num_bsnd_heads, uint32_t is_lower = 0, + __gm__ int32_t *cu_seqlens = nullptr +) { + static_assert( + std::is_same_v or std::is_same_v, + "tri_inv_rec_unroll supports only fp16 or bf16." + ); + + static_assert( + std::is_same_v or std::is_same_v, + "tri_inv_rec_unroll supports only fp16 or bf16." + ); + switch (matrix_size) { case 16: - runKernelTriInvRecUnroll(tensor_out, tensor_in, minus_eye_in, - num_matrices, num_bsnd_heads, is_lower, - cu_seqlens); - break; + runKernelTriInvRecUnroll( + tensor_out, tensor_in, minus_eye_in, block_dim, num_matrices, num_bsnd_heads, is_lower, cu_seqlens + ); + break; case 32: - runKernelTriInvRecUnroll(tensor_out, tensor_in, minus_eye_in, - num_matrices, num_bsnd_heads, is_lower, - cu_seqlens); - break; + runKernelTriInvRecUnroll( + tensor_out, tensor_in, minus_eye_in, block_dim, num_matrices, num_bsnd_heads, is_lower, cu_seqlens + ); + break; case 64: - runKernelTriInvRecUnroll(tensor_out, tensor_in, minus_eye_in, - num_matrices, num_bsnd_heads, is_lower, - cu_seqlens); - break; + runKernelTriInvRecUnroll( + tensor_out, tensor_in, minus_eye_in, block_dim, num_matrices, num_bsnd_heads, is_lower, cu_seqlens + ); + break; case 128: - runKernelTriInvRecUnroll(tensor_out, tensor_in, minus_eye_in, - num_matrices, num_bsnd_heads, is_lower, - cu_seqlens); - break; - } + runKernelTriInvRecUnroll( + tensor_out, tensor_in, minus_eye_in, block_dim, num_matrices, num_bsnd_heads, is_lower, cu_seqlens + ); + break; + } } -template +template AICORE void run_tri_inv_rec_unroll_per_num_matrices( - __gm__ OutputT* tensor_out, __gm__ InputT* tensor_in, - __gm__ InputT* minus_eye_in, uint32_t matrix_size, uint32_t num_matrices, - uint32_t num_bsnd_heads, uint32_t is_lower = 0, - __gm__ int32_t* cu_seqlens = nullptr) { - if (num_bsnd_heads == 0) { - if (num_matrices <= get_block_num()) { - run_tri_inv_rec_unroll( - tensor_out, tensor_in, minus_eye_in, matrix_size, 
num_matrices, - num_bsnd_heads, is_lower, cu_seqlens); - } else if (num_matrices <= 2 * get_block_num()) { - run_tri_inv_rec_unroll( - tensor_out, tensor_in, minus_eye_in, matrix_size, num_matrices, - num_bsnd_heads, is_lower, cu_seqlens); - } else { - run_tri_inv_rec_unroll( - tensor_out, tensor_in, minus_eye_in, matrix_size, num_matrices, - num_bsnd_heads, is_lower, cu_seqlens); - } - } else { - if (num_matrices <= get_block_num()) { - run_tri_inv_rec_unroll( - tensor_out, tensor_in, minus_eye_in, matrix_size, num_matrices, - num_bsnd_heads, is_lower, cu_seqlens); - } else if (num_matrices <= 2 * get_block_num()) { - run_tri_inv_rec_unroll( - tensor_out, tensor_in, minus_eye_in, matrix_size, num_matrices, - num_bsnd_heads, is_lower, cu_seqlens); + __gm__ OutputT *tensor_out, __gm__ InputT *tensor_in, __gm__ InputT *minus_eye_in, uint32_t block_dim, + uint32_t matrix_size, uint32_t num_matrices, uint32_t num_bsnd_heads, uint32_t is_lower = 0, + __gm__ int32_t *cu_seqlens = nullptr +) { + if (num_bsnd_heads == 0) { + if (num_matrices <= block_dim) { + run_tri_inv_rec_unroll( + tensor_out, tensor_in, minus_eye_in, block_dim, matrix_size, num_matrices, num_bsnd_heads, is_lower, + cu_seqlens + ); + } else if (num_matrices <= 2 * block_dim) { + run_tri_inv_rec_unroll( + tensor_out, tensor_in, minus_eye_in, block_dim, matrix_size, num_matrices, num_bsnd_heads, is_lower, + cu_seqlens + ); + } else { + run_tri_inv_rec_unroll( + tensor_out, tensor_in, minus_eye_in, block_dim, matrix_size, num_matrices, num_bsnd_heads, is_lower, + cu_seqlens + ); + } } else { - run_tri_inv_rec_unroll( - tensor_out, tensor_in, minus_eye_in, matrix_size, num_matrices, - num_bsnd_heads, is_lower, cu_seqlens); + if (num_matrices <= block_dim) { + run_tri_inv_rec_unroll( + tensor_out, tensor_in, minus_eye_in, block_dim, matrix_size, num_matrices, num_bsnd_heads, is_lower, + cu_seqlens + ); + } else if (num_matrices <= 2 * block_dim) { + run_tri_inv_rec_unroll( + tensor_out, tensor_in, 
minus_eye_in, block_dim, matrix_size, num_matrices, num_bsnd_heads, is_lower, + cu_seqlens + ); + } else { + run_tri_inv_rec_unroll( + tensor_out, tensor_in, minus_eye_in, block_dim, matrix_size, num_matrices, num_bsnd_heads, is_lower, + cu_seqlens + ); + } } - } } // --------------------------------------------------------------------------- // simpler framework entry point // --------------------------------------------------------------------------- -extern "C" __aicore__ __attribute__((always_inline)) void -kernel_entry(__gm__ int64_t* args) { - __gm__ Tensor* M_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); - __gm__ Tensor* I_neg_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); - __gm__ Tensor* M_inv_tensor = reinterpret_cast<__gm__ Tensor*>(args[2]); - __gm__ Tensor* config_tensor = reinterpret_cast<__gm__ Tensor*>(args[3]); - - __gm__ int64_t* cfg = - reinterpret_cast<__gm__ int64_t*>(config_tensor->buffer.addr); - const uint32_t matrix_size = static_cast(cfg[0]); - const uint32_t num_matrices = static_cast(cfg[1]); - const uint32_t is_lower = static_cast(cfg[2]); - - __gm__ half* M = reinterpret_cast<__gm__ half*>(M_tensor->buffer.addr) + - M_tensor->start_offset; - __gm__ half* I_neg = - reinterpret_cast<__gm__ half*>(I_neg_tensor->buffer.addr) + - I_neg_tensor->start_offset; - __gm__ half* M_inv = - reinterpret_cast<__gm__ half*>(M_inv_tensor->buffer.addr) + - M_inv_tensor->start_offset; - - run_tri_inv_rec_unroll_per_num_matrices( - M_inv, M, I_neg, matrix_size, num_matrices, 0, is_lower, nullptr); +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t *args) { + __gm__ Tensor *M_tensor = reinterpret_cast<__gm__ Tensor *>(args[0]); + __gm__ Tensor *I_neg_tensor = reinterpret_cast<__gm__ Tensor *>(args[1]); + __gm__ Tensor *M_inv_tensor = reinterpret_cast<__gm__ Tensor *>(args[2]); + __gm__ Tensor *config_tensor = reinterpret_cast<__gm__ Tensor *>(args[3]); + + __gm__ int64_t *cfg = reinterpret_cast<__gm__ int64_t 
*>(config_tensor->buffer.addr); + const uint32_t matrix_size = static_cast(cfg[0]); + const uint32_t num_matrices = static_cast(cfg[1]); + const uint32_t is_lower = static_cast(cfg[2]); + const uint32_t block_dim = static_cast(cfg[3]); + + __gm__ half *M = reinterpret_cast<__gm__ half *>(M_tensor->buffer.addr) + M_tensor->start_offset; + __gm__ half *I_neg = reinterpret_cast<__gm__ half *>(I_neg_tensor->buffer.addr) + I_neg_tensor->start_offset; + __gm__ half *M_inv = reinterpret_cast<__gm__ half *>(M_inv_tensor->buffer.addr) + M_inv_tensor->start_offset; + + run_tri_inv_rec_unroll_per_num_matrices( + M_inv, M, I_neg, block_dim, matrix_size, num_matrices, 0, is_lower, nullptr + ); } diff --git a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/orchestration/triangular_inverse_orch.cpp b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/orchestration/triangular_inverse_orch.cpp index f02ddbfae..f0c91b46f 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/orchestration/triangular_inverse_orch.cpp +++ b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/orchestration/triangular_inverse_orch.cpp @@ -17,7 +17,7 @@ * tensor(0) = M (INPUT) fp16 triangular matrices [num_matrices * N * N] * tensor(1) = I_neg (INPUT) fp16 negative identity [N * N] * tensor(2) = M_inv (OUTPUT) fp16 result [num_matrices * N * N] - * tensor(3) = config (INPUT) int64[3]: [matrix_size, num_matrices, is_lower] + * tensor(3) = config (INPUT) int64[4]: [matrix_size, num_matrices, is_lower, block_dim] * * The single AIC task (func_id=0) receives these four args in the same order * and dispatches to run_tri_inv_rec_unroll_per_num_matrices. 
@@ -33,28 +33,28 @@ extern "C" { __attribute__((visibility("default"))) PTO2OrchestrationConfig -aicpu_orchestration_config(const ChipStorageTaskArgs& orch_args) { +aicpu_orchestration_config(const ChipStorageTaskArgs &orch_args) { (void)orch_args; return PTO2OrchestrationConfig{ .expected_arg_count = 4, }; } -__attribute__((visibility("default"))) void -aicpu_orchestration_entry(const ChipStorageTaskArgs& orch_args) { - Tensor ext_M = from_tensor_arg(orch_args.tensor(0)); - Tensor ext_I_neg = from_tensor_arg(orch_args.tensor(1)); - Tensor ext_M_inv = from_tensor_arg(orch_args.tensor(2)); +__attribute__((visibility("default"))) void aicpu_orchestration_entry(const ChipStorageTaskArgs &orch_args) { + Tensor ext_M = from_tensor_arg(orch_args.tensor(0)); + Tensor ext_I_neg = from_tensor_arg(orch_args.tensor(1)); + Tensor ext_M_inv = from_tensor_arg(orch_args.tensor(2)); Tensor ext_config = from_tensor_arg(orch_args.tensor(3)); - int64_t* host_config = orch_args.tensor(3).data_as(); - int matrix_size = static_cast(host_config[0]); + int64_t *host_config = orch_args.tensor(3).data_as(); + int matrix_size = static_cast(host_config[0]); int num_matrices = static_cast(host_config[1]); - int is_lower = static_cast(host_config[2]); + int is_lower = static_cast(host_config[2]); + int block_dim = static_cast(host_config[3]); LOG_INFO_V0( - "[tri_inv_orch] matrix_size: %d, num_matrices: %d, is_lower: %d", - matrix_size, num_matrices, is_lower + "[tri_inv_orch] matrix_size: %d, num_matrices: %d, is_lower: %d, block_dim: %d", matrix_size, num_matrices, + is_lower, block_dim ); Arg params; diff --git a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py index 1bdce9c04..80cb85571 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py +++ 
b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py @@ -7,13 +7,37 @@ # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. # ----------------------------------------------------------------------------------------------------------- -"""Triangular inverse (recursive unrolled): M_inv = inv(M) for upper/lower-triangular fp16 matrices.""" +"""Triangular inverse (recursive unrolled): M_inv = inv(M) for upper/lower-triangular unit diagonal matrices.""" import torch +import numpy as np from simpler.task_interface import ArgDirection as D from simpler_setup import SceneTestCase, TaskArgsBuilder, Tensor, scene_test +def random_tri_matrix(n, block_dim_x, block_dim_y, scale=0.1, is_lower=False): + if is_lower: + return scale * torch.tril( + torch.rand((block_dim_x, block_dim_y, n, n)), diagonal=-1 + ) + else: + return scale * torch.triu( + torch.rand((block_dim_x, block_dim_y, n, n)), diagonal=1 + ) + +def linalg_inv(A: torch.tensor) -> torch.tensor: + assert A.ndim == 4, "Expected 4D tensor" + assert A.shape[-2] == A.shape[-1], "Expected square matrices on last two dimensions" + in_dtype = A.dtype + n = A.shape[-1] + Identity = np.eye(n, dtype=np.double) + golden_numpy = np.zeros((A.shape)) + for x in range(A.shape[0]): + for y in range(A.shape[1]): + golden_numpy[x, y] = np.linalg.inv( + A[x, y].double().numpy().astype(np.double) + Identity + ) + return torch.from_numpy(golden_numpy).to(in_dtype) @scene_test(level=2, runtime="tensormap_and_ringbuffer") class TestTriangularInverse(SceneTestCase): @@ -40,55 +64,44 @@ class TestTriangularInverse(SceneTestCase): CASES = [ { - "name": "Case0_upper32", + "name": "Case_upper_tri_matrix_size_32", "platforms": ["a2a3sim", "a2a3"], "config": {"aicpu_thread_num": 4, "block_dim": 16}, - "params": {"num_matrices": 4, "matrix_size": 32, "is_lower": 0}, + 
"params": {"num_matrices": 1, "matrix_size": 32, "is_lower": 0}, }, { - "name": "Case1_upper64", + "name": "Case_upper_tri_matrix_size_64", "manual": True, "platforms": ["a2a3sim", "a2a3"], "config": {"aicpu_thread_num": 4, "block_dim": 16}, - "params": {"num_matrices": 4, "matrix_size": 64, "is_lower": 0}, + "params": {"num_matrices": 16, "matrix_size": 64, "is_lower": 0}, }, { - "name": "Case2_lower32", + "name": "Case_lower_tri_matrix_size_32", "manual": True, "platforms": ["a2a3sim", "a2a3"], "config": {"aicpu_thread_num": 4, "block_dim": 16}, - "params": {"num_matrices": 4, "matrix_size": 32, "is_lower": 1}, + "params": {"num_matrices": 16, "matrix_size": 32, "is_lower": 1}, }, ] def generate_args(self, params): n = params["matrix_size"] num_matrices = params["num_matrices"] + block_dim = min(num_matrices, 20) is_lower = params["is_lower"] # Build well-conditioned triangular matrices in fp16. # Start with random values and zero out the off-triangle, then set # the diagonal to a value in [0.5, 1.5] to ensure invertibility. 
- M_f32 = torch.rand(num_matrices, n, n, dtype=torch.float32) - if is_lower: - M_f32 = torch.tril(M_f32) - else: - M_f32 = torch.triu(M_f32) - # Diagonal entries in [0.5, 1.5] — keeps condition number reasonable - diag_vals = torch.rand(num_matrices, n, dtype=torch.float32) + 0.5 - idx = torch.arange(n) - M_f32[:, idx, idx] = diag_vals - M = M_f32.to(torch.float16) - - # Negative identity matrix — used by the kernel to derive I and Zero - I_neg = (-torch.eye(n, dtype=torch.float16)).flatten() - - M_inv = torch.zeros(num_matrices, n, n, dtype=torch.float16) - config = torch.tensor([n, num_matrices, is_lower], dtype=torch.int64) + M_fp16 = random_tri_matrix(n, 1, num_matrices, is_lower=is_lower).to(torch.float16) + I_neg = -torch.eye(n, dtype=torch.float16) + M_inv = torch.zeros((num_matrices, 1, n , n), dtype=torch.float16) + config = torch.tensor([n, num_matrices, is_lower, block_dim], dtype=torch.int64) return TaskArgsBuilder( - Tensor("M", M.flatten()), - Tensor("I_neg", I_neg), + Tensor("M", M_fp16.flatten()), + Tensor("I_neg", I_neg.flatten()), Tensor("M_inv", M_inv.flatten()), Tensor("config", config), ) @@ -96,9 +109,9 @@ def generate_args(self, params): def compute_golden(self, args, params): n = params["matrix_size"] num_matrices = params["num_matrices"] - M = args.M.reshape(num_matrices, n, n).to(torch.float32) - # Compute in float32 for numerical stability, then cast to fp16 - M_inv = torch.linalg.inv(M).to(torch.float16) + M = args.M.reshape(num_matrices,1, n, n).to(torch.float16) + M_inv = linalg_inv(M) + print("M_inv (golden):", M_inv) return M_inv From ece47e36271b4f285a4000600e248f4c8a6a9d0f Mon Sep 17 00:00:00 2001 From: anastasios Date: Wed, 20 May 2026 06:29:11 +0000 Subject: [PATCH 07/16] fix --- .../test_triangular_inverse.py | 23 +++++++++++++------ simpler_setup/scene_test.py | 2 ++ 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py 
b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py index 80cb85571..16e4f37b5 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py +++ b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py @@ -15,6 +15,7 @@ from simpler_setup import SceneTestCase, TaskArgsBuilder, Tensor, scene_test + def random_tri_matrix(n, block_dim_x, block_dim_y, scale=0.1, is_lower=False): if is_lower: return scale * torch.tril( @@ -25,6 +26,7 @@ def random_tri_matrix(n, block_dim_x, block_dim_y, scale=0.1, is_lower=False): torch.rand((block_dim_x, block_dim_y, n, n)), diagonal=1 ) + def linalg_inv(A: torch.tensor) -> torch.tensor: assert A.ndim == 4, "Expected 4D tensor" assert A.shape[-2] == A.shape[-1], "Expected square matrices on last two dimensions" @@ -39,10 +41,11 @@ def linalg_inv(A: torch.tensor) -> torch.tensor: ) return torch.from_numpy(golden_numpy).to(in_dtype) + @scene_test(level=2, runtime="tensormap_and_ringbuffer") class TestTriangularInverse(SceneTestCase): # fp16 arithmetic — use tolerances appropriate for half-precision results - RTOL = 1e-2 + RTOL = 1e-5 ATOL = 1e-2 CALLABLE = { @@ -66,7 +69,7 @@ class TestTriangularInverse(SceneTestCase): { "name": "Case_upper_tri_matrix_size_32", "platforms": ["a2a3sim", "a2a3"], - "config": {"aicpu_thread_num": 4, "block_dim": 16}, + "config": {"aicpu_thread_num": 4, "block_dim": 1}, "params": {"num_matrices": 1, "matrix_size": 32, "is_lower": 0}, }, { @@ -94,9 +97,15 @@ def generate_args(self, params): # Build well-conditioned triangular matrices in fp16. # Start with random values and zero out the off-triangle, then set # the diagonal to a value in [0.5, 1.5] to ensure invertibility. 
- M_fp16 = random_tri_matrix(n, 1, num_matrices, is_lower=is_lower).to(torch.float16) - I_neg = -torch.eye(n, dtype=torch.float16) - M_inv = torch.zeros((num_matrices, 1, n , n), dtype=torch.float16) + M_fp16 = ( + random_tri_matrix(n, 1, num_matrices, is_lower=is_lower) + .to(torch.float16) + .contiguous() + ) + I_neg = -torch.eye( + n, dtype=torch.float16 + ).contiguous() # Identity matrix for use in the kernel, negated to match the kernel's expected input + M_inv = torch.zeros((num_matrices, 1, n, n), dtype=torch.float16) config = torch.tensor([n, num_matrices, is_lower, block_dim], dtype=torch.int64) return TaskArgsBuilder( @@ -109,9 +118,9 @@ def generate_args(self, params): def compute_golden(self, args, params): n = params["matrix_size"] num_matrices = params["num_matrices"] - M = args.M.reshape(num_matrices,1, n, n).to(torch.float16) + M = args.M.reshape(num_matrices, 1, n, n).to(torch.float16) M_inv = linalg_inv(M) - print("M_inv (golden):", M_inv) + print("M_inv (golden):", M_inv.flatten()[:10]) return M_inv diff --git a/simpler_setup/scene_test.py b/simpler_setup/scene_test.py index 5a994b332..c96ec7aef 100644 --- a/simpler_setup/scene_test.py +++ b/simpler_setup/scene_test.py @@ -693,6 +693,8 @@ def _compare_outputs(test_args, golden_args, output_names, rtol, atol): for name in output_names: actual = getattr(test_args, name) expected = getattr(golden_args, name) + print("actual: ", actual[:10]) + print("diff: ", (actual - expected).abs().mean().item()) if not torch.allclose(actual, expected, rtol=rtol, atol=atol): diff = (actual - expected).abs().max().item() raise AssertionError(f"Golden mismatch on '{name}': max_diff={diff}, rtol={rtol}, atol={atol}") From 7f69bead13f1d0ec657862215ccf1a3dc40d0e95 Mon Sep 17 00:00:00 2001 From: mirkodevita Date: Wed, 20 May 2026 13:34:42 +0200 Subject: [PATCH 08/16] Fix triangular inverse a2a3sim compile and golden check --- .../kernels/aic/kernel_tri_inv_rec_unroll.cpp | 11 ++++++++--- .../test_triangular_inverse.py | 
2 +- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_tri_inv_rec_unroll.cpp b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_tri_inv_rec_unroll.cpp index 19471ccfb..8f7f66992 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_tri_inv_rec_unroll.cpp +++ b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_tri_inv_rec_unroll.cpp @@ -694,14 +694,19 @@ AICORE inline void TriInvRecUnrollKernel( M_inv + bsnd_offset, {1, 1, 1, static_cast(valid_size), static_cast(valid_size)}, {1, 1, 1, row_stride, 1} ); - TSTORE(M_inv_global_out_dyn, c_l0_tail_tile); + TileL1ABDynamic X_dyn_l1_tile(valid_size, valid_size); + TASSIGN(X_dyn_l1_tile, 0x0 + 4 * TileLen * sizeof(InputT)); + TMOV(X_dyn_l1_tile, c_l0_tail_tile); + TSTORE(M_inv_global_out_dyn, X_dyn_l1_tile); } else { GlobalTileOut M_inv_global_out(M_inv + bsnd_offset, {}, {row_stride}); - TSTORE(M_inv_global_out, c_l0_tile[final_c_buffer_index]); + TMOV(X_l1_tile, c_l0_tile[final_c_buffer_index]); + TSTORE(M_inv_global_out, X_l1_tile); } } else { GlobalTileOut M_inv_global_out(M_inv + (global_index + tile_id) * TileLen); - TSTORE(M_inv_global_out, c_l0_tile[final_c_buffer_index]); + TMOV(X_l1_tile, c_l0_tile[final_c_buffer_index]); + TSTORE(M_inv_global_out, X_l1_tile); } next_tile_id_that_waits_for_pipe_fix_pipe_m = (tile_id + 1) % NumTilesPerCubeIter; set_flag(PIPE_FIX, PIPE_M, static_cast(next_tile_id_that_waits_for_pipe_fix_pipe_m)); diff --git a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py index 16e4f37b5..8b52f05db 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py +++ 
b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py @@ -121,7 +121,7 @@ def compute_golden(self, args, params): M = args.M.reshape(num_matrices, 1, n, n).to(torch.float16) M_inv = linalg_inv(M) print("M_inv (golden):", M_inv.flatten()[:10]) - return M_inv + args.M_inv.copy_(M_inv.flatten()) if __name__ == "__main__": From 7f0694de5008bcef945cca37af9f34e62203705f Mon Sep 17 00:00:00 2001 From: mirkodevita Date: Wed, 20 May 2026 16:12:35 +0200 Subject: [PATCH 09/16] fix compile error for a2a3 mode --- .../kernels/aic/kernel_tri_inv_rec_unroll.cpp | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_tri_inv_rec_unroll.cpp b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_tri_inv_rec_unroll.cpp index 8f7f66992..e439ccf11 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_tri_inv_rec_unroll.cpp +++ b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_tri_inv_rec_unroll.cpp @@ -694,19 +694,31 @@ AICORE inline void TriInvRecUnrollKernel( M_inv + bsnd_offset, {1, 1, 1, static_cast(valid_size), static_cast(valid_size)}, {1, 1, 1, row_stride, 1} ); +#ifdef __CPU_SIM TileL1ABDynamic X_dyn_l1_tile(valid_size, valid_size); TASSIGN(X_dyn_l1_tile, 0x0 + 4 * TileLen * sizeof(InputT)); TMOV(X_dyn_l1_tile, c_l0_tail_tile); TSTORE(M_inv_global_out_dyn, X_dyn_l1_tile); +#else + TSTORE(M_inv_global_out_dyn, c_l0_tail_tile); +#endif } else { GlobalTileOut M_inv_global_out(M_inv + bsnd_offset, {}, {row_stride}); +#ifdef __CPU_SIM TMOV(X_l1_tile, c_l0_tile[final_c_buffer_index]); TSTORE(M_inv_global_out, X_l1_tile); +#else + TSTORE(M_inv_global_out, c_l0_tile[final_c_buffer_index]); +#endif } } else { GlobalTileOut M_inv_global_out(M_inv + (global_index + tile_id) * TileLen); +#ifdef __CPU_SIM TMOV(X_l1_tile, 
c_l0_tile[final_c_buffer_index]); TSTORE(M_inv_global_out, X_l1_tile); +#else + TSTORE(M_inv_global_out, c_l0_tile[final_c_buffer_index]); +#endif } next_tile_id_that_waits_for_pipe_fix_pipe_m = (tile_id + 1) % NumTilesPerCubeIter; set_flag(PIPE_FIX, PIPE_M, static_cast(next_tile_id_that_waits_for_pipe_fix_pipe_m)); From 3ad4a3293567db205881968c31ac1abe86c55a2c Mon Sep 17 00:00:00 2001 From: anastasios Date: Fri, 22 May 2026 13:38:13 +0000 Subject: [PATCH 10/16] (examples) revert bgemm changes --- .../benchmark_bgemm/test_benchmark_bgemm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/a2a3/tensormap_and_ringbuffer/benchmark_bgemm/test_benchmark_bgemm.py b/examples/a2a3/tensormap_and_ringbuffer/benchmark_bgemm/test_benchmark_bgemm.py index d8b073922..a3b888f75 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/benchmark_bgemm/test_benchmark_bgemm.py +++ b/examples/a2a3/tensormap_and_ringbuffer/benchmark_bgemm/test_benchmark_bgemm.py @@ -17,8 +17,8 @@ @scene_test(level=2, runtime="tensormap_and_ringbuffer") class TestBenchmarkBgemm(SceneTestCase): - RTOL = 1e-5 - ATOL = 1e-5 + RTOL = 1e-3 + ATOL = 1e-3 CALLABLE = { "orchestration": { @@ -32,14 +32,14 @@ class TestBenchmarkBgemm(SceneTestCase): "name": "GEMM", "source": "kernels/aic/kernel_gemm_tile.cpp", "core_type": "aic", - "signature": [D.IN, D.IN, D.OUT, D.IN], + "signature": [D.IN, D.IN, D.OUT], }, { "func_id": 1, "name": "ADD", "source": "kernels/aiv/kernel_tile_add.cpp", "core_type": "aiv", - "signature": [D.INOUT, D.IN, D.IN], + "signature": [D.INOUT, D.IN], }, ], } From 74e97a49f10eaed6c4b01327407025e9a3d70a26 Mon Sep 17 00:00:00 2001 From: anastasios Date: Fri, 22 May 2026 13:41:31 +0000 Subject: [PATCH 11/16] remove Makefile --- Makefile | 8 -------- 1 file changed, 8 deletions(-) delete mode 100644 Makefile diff --git a/Makefile b/Makefile deleted file mode 100644 index 2f854df67..000000000 --- a/Makefile +++ /dev/null @@ -1,8 +0,0 @@ - - -all: - python 
examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py -p a2a3sim - - -run_on_npu: - python examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py -p a2a3 From c13bc266f42dfb50ac1296261af51911db7d5ed5 Mon Sep 17 00:00:00 2001 From: anastasios Date: Fri, 22 May 2026 13:48:23 +0000 Subject: [PATCH 12/16] remove prints --- .../triangular_inverse_example/test_triangular_inverse.py | 1 - simpler_setup/scene_test.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py index 8b52f05db..b85dc702f 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py +++ b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py @@ -120,7 +120,6 @@ def compute_golden(self, args, params): num_matrices = params["num_matrices"] M = args.M.reshape(num_matrices, 1, n, n).to(torch.float16) M_inv = linalg_inv(M) - print("M_inv (golden):", M_inv.flatten()[:10]) args.M_inv.copy_(M_inv.flatten()) diff --git a/simpler_setup/scene_test.py b/simpler_setup/scene_test.py index c96ec7aef..5a994b332 100644 --- a/simpler_setup/scene_test.py +++ b/simpler_setup/scene_test.py @@ -693,8 +693,6 @@ def _compare_outputs(test_args, golden_args, output_names, rtol, atol): for name in output_names: actual = getattr(test_args, name) expected = getattr(golden_args, name) - print("actual: ", actual[:10]) - print("diff: ", (actual - expected).abs().mean().item()) if not torch.allclose(actual, expected, rtol=rtol, atol=atol): diff = (actual - expected).abs().max().item() raise AssertionError(f"Golden mismatch on '{name}': max_diff={diff}, rtol={rtol}, atol={atol}") From 26090afdad5119f5cfd629c71bfe35dce726172a Mon Sep 17 00:00:00 2001 From: anastasios Date: Fri, 22 May 2026 
14:11:04 +0000 Subject: [PATCH 13/16] make ruff happy --- .../test_triangular_inverse.py | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py index b85dc702f..f7d996a42 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py +++ b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py @@ -18,13 +18,9 @@ def random_tri_matrix(n, block_dim_x, block_dim_y, scale=0.1, is_lower=False): if is_lower: - return scale * torch.tril( - torch.rand((block_dim_x, block_dim_y, n, n)), diagonal=-1 - ) + return scale * torch.tril(torch.rand((block_dim_x, block_dim_y, n, n)), diagonal=-1) else: - return scale * torch.triu( - torch.rand((block_dim_x, block_dim_y, n, n)), diagonal=1 - ) + return scale * torch.triu(torch.rand((block_dim_x, block_dim_y, n, n)), diagonal=1) def linalg_inv(A: torch.tensor) -> torch.tensor: @@ -36,9 +32,7 @@ def linalg_inv(A: torch.tensor) -> torch.tensor: golden_numpy = np.zeros((A.shape)) for x in range(A.shape[0]): for y in range(A.shape[1]): - golden_numpy[x, y] = np.linalg.inv( - A[x, y].double().numpy().astype(np.double) + Identity - ) + golden_numpy[x, y] = np.linalg.inv(A[x, y].double().numpy().astype(np.double) + Identity) return torch.from_numpy(golden_numpy).to(in_dtype) @@ -97,11 +91,7 @@ def generate_args(self, params): # Build well-conditioned triangular matrices in fp16. # Start with random values and zero out the off-triangle, then set # the diagonal to a value in [0.5, 1.5] to ensure invertibility. 
- M_fp16 = ( - random_tri_matrix(n, 1, num_matrices, is_lower=is_lower) - .to(torch.float16) - .contiguous() - ) + M_fp16 = random_tri_matrix(n, 1, num_matrices, is_lower=is_lower).to(torch.float16).contiguous() I_neg = -torch.eye( n, dtype=torch.float16 ).contiguous() # Identity matrix for use in the kernel, negated to match the kernel's expected input From b480a6d92f0c350e2e2ba64e670726d050997ebc Mon Sep 17 00:00:00 2001 From: anastasios Date: Fri, 22 May 2026 14:13:46 +0000 Subject: [PATCH 14/16] make linters happy --- .../kernels/aic/kernel_tri_inv_rec_unroll.cpp | 11 +++++------ .../test_triangular_inverse.py | 6 +++--- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_tri_inv_rec_unroll.cpp b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_tri_inv_rec_unroll.cpp index e439ccf11..b642bc86d 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_tri_inv_rec_unroll.cpp +++ b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/aic/kernel_tri_inv_rec_unroll.cpp @@ -31,10 +31,6 @@ for the full License text. 
using namespace pto; -#ifndef GM_ADDR -#define GM_ADDR __gm__ uint8_t * -#endif - #ifndef __gm__ #define __gm__ #endif @@ -47,6 +43,7 @@ namespace tri_inv_utils { template AICORE inline void SetWaitFlag(uint32_t id) { + using pto::event_t; set_flag(SrcPipe, DstPipe, static_cast(id)); wait_flag(SrcPipe, DstPipe, static_cast(id)); } @@ -91,7 +88,6 @@ AICORE inline BSNDVarlenTileInfo GetBSNDVarlenTileInfoFromCuSeqlens( seq_start = seq_end; } } - } // namespace tri_inv_utils using namespace tri_inv_utils; @@ -212,6 +208,7 @@ AICORE inline void PrepareAuxiliaryMatrices( TileL1AB I_neg_l1_tile, TileL1AB Zero_l1_tile, TileL1AB I_l1_tile, TileL0A a_l0_tile, TileL0B b_l0_tile, TileL0C c_l0_tile ) { + using pto::event_t; TMOV(a_l0_tile, I_neg_l1_tile); // a_l0 initialized with I_neg TMOV(b_l0_tile, I_neg_l1_tile); // b_l0 initialized with I_neg set_flag(PIPE_MTE1, PIPE_M, static_cast(0)); @@ -278,6 +275,7 @@ AICORE inline void InvertSingleTile( TileL1AB Y_l1_tile, TileL0A *a_l0_tile, TileL0B *b_l0_tile, TileL0C *c_l0_tile, const uint32_t tile_id, const bool swap_parity = false ) { + using pto::event_t; const event_t event_0 = static_cast(tile_id); const event_t event_1 = static_cast(tile_id + NumTilesPerCubeIter); @@ -535,6 +533,7 @@ AICORE inline void TriInvRecUnrollKernel( __gm__ OutputT *M_inv, __gm__ InputT *M, __gm__ InputT *I_neg, uint32_t block_dim, uint32_t num_matrices, uint32_t num_bsnd_heads = 0, uint32_t is_lower = 0, __gm__ int32_t *cu_seqlens = nullptr ) { + using pto::event_t; /* Initializations */ constexpr uint32_t TileLen = MatrixSize * MatrixSize; constexpr uint32_t FractalSize = 16; // fractal size for half /bf16 @@ -725,7 +724,7 @@ AICORE inline void TriInvRecUnrollKernel( } } for (uint32_t tile_id = 0; tile_id < NumTilesPerCubeIter; ++tile_id) { - wait_flag(PIPE_M, PIPE_MTE2, static_cast(tile_id)); + wait_flag(PIPE_M, PIPE_MTE2, static_cast(tile_id)); } wait_flag(PIPE_FIX, PIPE_M, static_cast(next_tile_id_that_waits_for_pipe_fix_pipe_m)); } diff --git 
a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py index f7d996a42..f900f4121 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py +++ b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py @@ -29,7 +29,7 @@ def linalg_inv(A: torch.tensor) -> torch.tensor: in_dtype = A.dtype n = A.shape[-1] Identity = np.eye(n, dtype=np.double) - golden_numpy = np.zeros((A.shape)) + golden_numpy = np.zeros(A.shape) for x in range(A.shape[0]): for y in range(A.shape[1]): golden_numpy[x, y] = np.linalg.inv(A[x, y].double().numpy().astype(np.double) + Identity) @@ -91,10 +91,10 @@ def generate_args(self, params): # Build well-conditioned triangular matrices in fp16. # Start with random values and zero out the off-triangle, then set # the diagonal to a value in [0.5, 1.5] to ensure invertibility. 
- M_fp16 = random_tri_matrix(n, 1, num_matrices, is_lower=is_lower).to(torch.float16).contiguous() + M_fp16 = random_tri_matrix(n, 1, num_matrices, is_lower=is_lower).to(torch.float16) I_neg = -torch.eye( n, dtype=torch.float16 - ).contiguous() # Identity matrix for use in the kernel, negated to match the kernel's expected input + ) M_inv = torch.zeros((num_matrices, 1, n, n), dtype=torch.float16) config = torch.tensor([n, num_matrices, is_lower, block_dim], dtype=torch.int64) From 9bbf06e45c2dc86eabf20df3a4fa938d5ed21302 Mon Sep 17 00:00:00 2001 From: anastasios Date: Fri, 22 May 2026 14:26:46 +0000 Subject: [PATCH 15/16] (tri_inv) improve compute_golden --- .../test_triangular_inverse.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py index f900f4121..3a21ed961 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py +++ b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py @@ -83,7 +83,7 @@ class TestTriangularInverse(SceneTestCase): ] def generate_args(self, params): - n = params["matrix_size"] + matrix_size = params["matrix_size"] num_matrices = params["num_matrices"] block_dim = min(num_matrices, 20) is_lower = params["is_lower"] @@ -91,12 +91,12 @@ def generate_args(self, params): # Build well-conditioned triangular matrices in fp16. # Start with random values and zero out the off-triangle, then set # the diagonal to a value in [0.5, 1.5] to ensure invertibility. 
- M_fp16 = random_tri_matrix(n, 1, num_matrices, is_lower=is_lower).to(torch.float16) + M_fp16 = random_tri_matrix(matrix_size, 1, num_matrices, is_lower=is_lower).to(torch.float16) I_neg = -torch.eye( - n, dtype=torch.float16 + matrix_size, dtype=torch.float16 ) - M_inv = torch.zeros((num_matrices, 1, n, n), dtype=torch.float16) - config = torch.tensor([n, num_matrices, is_lower, block_dim], dtype=torch.int64) + M_inv = torch.randn((num_matrices, 1, matrix_size, matrix_size), dtype=torch.float16) + config = torch.tensor([matrix_size, num_matrices, is_lower, block_dim], dtype=torch.int64) return TaskArgsBuilder( Tensor("M", M_fp16.flatten()), @@ -108,9 +108,9 @@ def generate_args(self, params): def compute_golden(self, args, params): n = params["matrix_size"] num_matrices = params["num_matrices"] - M = args.M.reshape(num_matrices, 1, n, n).to(torch.float16) - M_inv = linalg_inv(M) - args.M_inv.copy_(M_inv.flatten()) + M = args.M.reshape(1, num_matrices, n, n) + M_inv = args.M_inv.reshape(1, num_matrices, n, n) + M_inv[:] = linalg_inv(M) if __name__ == "__main__": From dd7f332c0a8dc542f2e799988ac5b5b7e6eb2ffc Mon Sep 17 00:00:00 2001 From: anastasios Date: Fri, 22 May 2026 16:57:36 +0000 Subject: [PATCH 16/16] make ruff happy --- .../kernels/orchestration/triangular_inverse_orch.cpp | 2 +- .../triangular_inverse_example/test_triangular_inverse.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/orchestration/triangular_inverse_orch.cpp b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/orchestration/triangular_inverse_orch.cpp index f0c91b46f..f8b57e817 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/orchestration/triangular_inverse_orch.cpp +++ b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/kernels/orchestration/triangular_inverse_orch.cpp @@ -17,7 +17,7 @@ * tensor(0) = M (INPUT) fp16 
triangular matrices [num_matrices * N * N] * tensor(1) = I_neg (INPUT) fp16 negative identity [N * N] * tensor(2) = M_inv (OUTPUT) fp16 result [num_matrices * N * N] - * tensor(3) = config (INPUT) int64[3]: [matrix_size, num_matrices, is_lower, block_dim] + * tensor(3) = config (INPUT) int64[4]: [matrix_size, num_matrices, is_lower, block_dim] * * The single AIC task (func_id=0) receives these four args in the same order * and dispatches to run_tri_inv_rec_unroll_per_num_matrices. diff --git a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py index 3a21ed961..e6a27a663 100644 --- a/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py +++ b/examples/a2a3/tensormap_and_ringbuffer/triangular_inverse_example/test_triangular_inverse.py @@ -92,9 +92,7 @@ def generate_args(self, params): # Start with random values and zero out the off-triangle, then set # the diagonal to a value in [0.5, 1.5] to ensure invertibility. M_fp16 = random_tri_matrix(matrix_size, 1, num_matrices, is_lower=is_lower).to(torch.float16) - I_neg = -torch.eye( - matrix_size, dtype=torch.float16 - ) + I_neg = -torch.eye(matrix_size, dtype=torch.float16) M_inv = torch.randn((num_matrices, 1, matrix_size, matrix_size), dtype=torch.float16) config = torch.tensor([matrix_size, num_matrices, is_lower, block_dim], dtype=torch.int64)