Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions docs/PTO_IR_manual.md
Original file line number Diff line number Diff line change
Expand Up @@ -4616,9 +4616,11 @@ pto.thistogram ins(<src>, <idx> : <src_type>, <idx_type>)
- `idx` rows and valid rows must match `src`.
- `idx` must have exactly one column.
- When `src` is `ui32`:
- `idx` must use `row_major + none_box` layout.
- `idx` cols and valid cols must match `src`.
- `idx` rows / valid rows must be `1` for `byte = 3` or `2`, `2` for `byte = 1`, and `3` for `byte = 0`.
- When `byte = 3`, `idx` is accepted but not semantically used by the A5 backend intrinsic; no additional layout or shape constraints are imposed beyond the generic `tile_buf`, `loc=vec`, `dtype=ui8`, and rank-2 requirements.
- When `byte = 2`, `1`, or `0`, `idx` must use `row_major + none_box` layout. `idx` cols and valid cols must match `src`.
- When `byte = 2`, `idx` rows / valid rows must be `1`.
- When `byte = 1`, `idx` rows / valid rows must be `2`.
- When `byte = 0`, `idx` rows / valid rows must be `3`.

**Hardware Mapping:**

Expand Down
36 changes: 20 additions & 16 deletions lib/PTO/IR/PTO.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6887,22 +6887,26 @@ LogicalResult THistogramOp::verify() {
if (!isKnownUnitExtent(idxShape[1]) || !isKnownUnitExtent(idxValid[1]))
return emitOpError("expects idx to have exactly one column when src element type is ui16");
} else {
if (idxTB.getBLayoutValueI32() != static_cast<int32_t>(pto::BLayout::RowMajor) ||
idxTB.getSLayoutValueI32() != static_cast<int32_t>(pto::SLayout::NoneBox))
return emitOpError(
"expects idx to use row_major + none_box layout when src element type is ui32");
if (!hasCompatibleKnownExtent(srcShape[1], idxShape[1]) ||
!hasCompatibleKnownExtent(srcValid[1], idxValid[1]))
return emitOpError("expects idx cols and valid cols to match src when src element type is ui32");

int64_t expectedIdxRows = 1;
if (byte == 1)
expectedIdxRows = 2;
else if (byte == 0)
expectedIdxRows = 3;
if (!hasCompatibleKnownExtent(idxShape[0], expectedIdxRows) ||
!hasCompatibleKnownExtent(idxValid[0], expectedIdxRows))
return emitOpError("expects idx rows/valid rows to match the byte-selected filter depth when src element type is ui32");
if (byte != 3) {
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If byte is not validated to be within the valid range [0, 3] for ui32 inputs, an invalid value (such as 4 or negative values) could bypass the layout and shape checks, potentially leading to undefined behavior or compiler crashes downstream. Adding a range check for byte ensures robustness.

      if (byte < 0 || byte > 3)
        return emitOpError("expects byte to be in the range [0, 3] when src element type is ui32");

      if (byte != 3) {

if (idxTB.getBLayoutValueI32() != static_cast<int32_t>(pto::BLayout::RowMajor) ||
idxTB.getSLayoutValueI32() != static_cast<int32_t>(pto::SLayout::NoneBox))
return emitOpError(
"expects idx to use row_major + none_box layout when src element type is ui32 and byte is 0, 1, or 2");
if (!hasCompatibleKnownExtent(srcShape[1], idxShape[1]) ||
!hasCompatibleKnownExtent(srcValid[1], idxValid[1]))
return emitOpError(
"expects idx cols and valid cols to match src when src element type is ui32 and byte is 0, 1, or 2");

int64_t expectedIdxRows = 1;
if (byte == 1)
expectedIdxRows = 2;
else if (byte == 0)
expectedIdxRows = 3;
if (!hasCompatibleKnownExtent(idxShape[0], expectedIdxRows) ||
!hasCompatibleKnownExtent(idxValid[0], expectedIdxRows))
return emitOpError(
"expects idx rows/valid rows to match the byte-selected filter depth when src element type is ui32 and byte is 0, 1, or 2");
}
}
if (dstShape[1] != ShapedType::kDynamic && dstShape[1] < 256)
return emitOpError("expects dst shape[1] to be at least 256");
Expand Down
14 changes: 14 additions & 0 deletions test/lit/pto/thistogram_verify_u32_byte3_unused_idx_a5.pto
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s

module {
func.func @thistogram_verify_u32_byte3_unused_idx_a5() {
%src = pto.alloc_tile : !pto.tile_buf<loc=vec, dtype=ui32, rows=8, cols=32, v_row=8, v_col=32, blayout=row_major, slayout=none_box, fractal=512, pad=0>
%idx = pto.alloc_tile : !pto.tile_buf<loc=vec, dtype=ui8, rows=32, cols=11, v_row=32, v_col=11, blayout=col_major, slayout=none_box, fractal=512, pad=0>
%dst = pto.alloc_tile : !pto.tile_buf<loc=vec, dtype=ui32, rows=8, cols=256, v_row=8, v_col=256, blayout=row_major, slayout=none_box, fractal=512, pad=0>
pto.thistogram ins(%src, %idx : !pto.tile_buf<loc=vec, dtype=ui32, rows=8, cols=32, v_row=8, v_col=32, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.tile_buf<loc=vec, dtype=ui8, rows=32, cols=11, v_row=32, v_col=11, blayout=col_major, slayout=none_box, fractal=512, pad=0>)
outs(%dst : !pto.tile_buf<loc=vec, dtype=ui32, rows=8, cols=256, v_row=8, v_col=256, blayout=row_major, slayout=none_box, fractal=512, pad=0>) {byte = 3 : i32}
return
}
}

// CHECK-NOT: error:
Loading