Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 44 additions & 16 deletions src/compiler/codegen/control_flow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,26 @@ function emit_block!(ctx::CGCtx, block::Block; skip_terminator::Bool=false)
ctx.debug_emitter, ctx.sci, inst.ssa_idx; linkage_name=ln)
end
s = inst[:stmt]
if s isa ControlFlowOp
emit_control_flow_op!(ctx, s, value_type(inst), inst.ssa_idx)
else
emit_statement!(ctx, s, inst.ssa_idx, value_type(inst))
# Per-statement diagnostic boundary: an `IRError` from anywhere in
# this statement's emission is recorded (with its kernel-side stack)
# and the result poisoned, so emission continues and one compile can
# report every problem. See compiler/codegen/errors.jl.
ctx.current_ssa_idx = inst.ssa_idx
ctx.touched_poison = false
try
if s isa ControlFlowOp
emit_control_flow_op!(ctx, s, value_type(inst), inst.ssa_idx)
else
emit_statement!(ctx, s, inst.ssa_idx, value_type(inst))
end
catch e
e isa IRError || rethrow()
ctx.current_ssa_idx = inst.ssa_idx
# Suppress errors derived purely from an already-poisoned input;
# keep only root causes.
ctx.touched_poison || record_error!(ctx, e.msg)
push!(ctx.poisoned, inst.ssa_idx)
ctx.values[inst.ssa_idx] = poison_value(ctx)
end
end
if !skip_terminator && terminator(block) !== nothing
Expand Down Expand Up @@ -62,16 +78,19 @@ function emit_if_op!(ctx::CGCtx, op::IfOp, @nospecialize(parent_result_type), ss
cond_tv === nothing && throw(IRError("Cannot resolve condition for IfOp"))

# Determine result types from parent_result_type (token_order_pass! already
# updated the type to include any token carries via inst[:type] = …)
# updated the type to include any token carries via inst[:type] = …).
#
# A non-representable result (typically `Any`/`Union{}` left when inference
# could not pin a branch's tile down) is recorded but does NOT abort: we
# substitute a poison type and still emit both regions, so the root cause
# inside the branch (e.g. a non-constant tile shape) surfaces too. The IfOp
# result is then marked poison to suppress the cascade at its consumers.
result_types = TypeId[]
result_poisoned = false
if parent_result_type !== Nothing
if parent_result_type <: Tuple
for T in parent_result_type.parameters
push!(result_types, tile_type_for_julia!(ctx, T))
end
else
push!(result_types, tile_type_for_julia!(ctx, parent_result_type))
end
Ts = parent_result_type <: Tuple ? collect(parent_result_type.parameters) :
Any[parent_result_type]
result_types, result_poisoned = collect_result_types!(ctx, Ts; context="`if`/`else` result")
end

then_body = function(_)
Expand All @@ -89,6 +108,7 @@ function emit_if_op!(ctx::CGCtx, op::IfOp, @nospecialize(parent_result_type), ss
results = encode_IfOp!(then_body, else_body, cb, result_types, cond_tv.v)

ctx.values[ssa_idx] = CGVal(results, parent_result_type)
result_poisoned && push!(ctx.poisoned, ssa_idx)
end

#=============================================================================
Expand Down Expand Up @@ -120,9 +140,12 @@ function emit_for_op!(ctx::CGCtx, op::ForOp, @nospecialize(parent_result_type),
push!(init_values, tv.v)
end

# Build result types uniformly from block args
# Build result types uniformly from block args. A non-representable carry
# (typically `Any`/`Union{}`) is recorded but does not abort: poison it and
# still emit the body so the root cause inside the loop surfaces too.
n_carries = length(body_blk.args)
result_types = TypeId[tile_type_for_julia!(ctx, arg.type) for arg in body_blk.args]
result_types, result_poisoned =
collect_result_types!(ctx, (arg.type for arg in body_blk.args); context="loop carry")

body_builder = function(block_args)
saved = copy(ctx.block_args)
Expand All @@ -146,6 +169,7 @@ function emit_for_op!(ctx::CGCtx, op::ForOp, @nospecialize(parent_result_type),
lower_tv.v, upper_tv.v, step_tv.v, init_values)

ctx.values[ssa_idx] = CGVal(results, parent_result_type)
result_poisoned && push!(ctx.poisoned, ssa_idx)
end

#=============================================================================
Expand All @@ -164,7 +188,8 @@ function emit_loop_op!(ctx::CGCtx, op::LoopOp, @nospecialize(parent_result_type)
end

n_carries = length(body_blk.args)
result_types = TypeId[tile_type_for_julia!(ctx, arg.type) for arg in body_blk.args]
result_types, result_poisoned =
collect_result_types!(ctx, (arg.type for arg in body_blk.args); context="loop carry")

body_builder = function(block_args)
saved = copy(ctx.block_args)
Expand All @@ -188,6 +213,7 @@ function emit_loop_op!(ctx::CGCtx, op::LoopOp, @nospecialize(parent_result_type)
results = encode_LoopOp!(body_builder, cb, result_types, init_values)

ctx.values[ssa_idx] = CGVal(results, parent_result_type)
result_poisoned && push!(ctx.poisoned, ssa_idx)
end

#=============================================================================
Expand All @@ -212,7 +238,8 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ
end

n_carries = length(before_blk.args)
result_types = TypeId[tile_type_for_julia!(ctx, arg.type) for arg in before_blk.args]
result_types, result_poisoned =
collect_result_types!(ctx, (arg.type for arg in before_blk.args); context="loop carry")

body_builder = function(block_args)
saved = copy(ctx.block_args)
Expand Down Expand Up @@ -301,6 +328,7 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ
results = encode_LoopOp!(body_builder, cb, result_types, init_values)

ctx.values[ssa_idx] = CGVal(results, parent_result_type)
result_poisoned && push!(ctx.poisoned, ssa_idx)
end

#=============================================================================
Expand Down
9 changes: 9 additions & 0 deletions src/compiler/codegen/kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,10 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
# Emit the structured IR (uses original Julia SSA indices everywhere)
emit_block!(ctx, ctx.sci.entry)

# Raise any deferred diagnostics accumulated during emission before we
# commit a (necessarily incomplete) function body.
report_errors!(ctx)

finalize_function!(func_buf, cb, writer.debug_info)
end

Expand Down Expand Up @@ -436,6 +440,11 @@ function emit_subprogram!(ctx::CGCtx, func, arg_types::Vector,
# 5. Emit body (skip terminator — we yield manually)
emit_block!(sub_ctx, sci.entry; skip_terminator=true)

# Subprograms compile in their own context; lift any deferred diagnostics
# (e.g. an unsupported op inside a reduce/scan combiner) into the parent so
# the top-level `report_errors!` surfaces them with their kernel-side stack.
append!(ctx.errors, sub_ctx.errors)

# 6. Extract return value and yield
ret = terminator(sci.entry)::ReturnNode
tv = emit_value!(sub_ctx, ret.val)
Expand Down
3 changes: 3 additions & 0 deletions src/compiler/codegen/values.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
Emit/resolve a value reference to a CGVal using multiple dispatch.
"""
function emit_value!(ctx::CGCtx, ssa::SSAValue)
# Reading a poisoned result marks the current statement as a cascade, so its
# (derived) failure is suppressed in favour of the root cause.
ssa.id in ctx.poisoned && (ctx.touched_poison = true)
tv = ctx[ssa]
tv !== nothing || throw(IRError("SSAValue %$(ssa.id) not found in context"))
return tv
Expand Down
4 changes: 4 additions & 0 deletions src/compiler/transform/canonicalize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,10 @@ function promote_scalar_type(@nospecialize(T))
T <: Number && return Tile{T, Tuple{}}
if T <: Tuple
params = T.parameters
# A `Vararg` tail (`Tuple{Vararg{…}}`) has a non-`Type` parameter that
# `widenconst` chokes on; such tuples are never concrete tiles anyway,
# so leave them for the codegen-side diagnostic to reject cleanly.
any(Base.isvarargtype, params) && return nothing
any_promoted = false
new_params = map(params) do P
P = CC.widenconst(P)
Expand Down
30 changes: 30 additions & 0 deletions src/compiler/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,20 @@ struct IRError <: Exception
end
# Display hook for `IRError`: writes the fixed "IRError: " prefix followed by
# the error's `msg` field to `io`.
function Base.showerror(io::IO, e::IRError)
    return print(io, "IRError: ", e.msg)
end

"""
CodegenError(msg, stack)

A single deferred codegen diagnostic: a message plus the kernel-side
inlining stack (`source_location`, ordered outermost→innermost) of the
statement that produced it. Instead of throwing on the first unsupported
construct, codegen accumulates these in `CGCtx.errors` and `report_errors!`
raises them together at the end of `emit_kernel!`.
"""
struct CodegenError
    # NOTE: declared as a plain immutable record (no `<: Exception` supertype);
    # instances are stored in `CGCtx.errors` rather than being exceptions
    # themselves.

    # Human-readable description of the problem.
    msg::String
    # Kernel-side inlining stack of the statement that produced the
    # diagnostic, ordered outermost→innermost.
    stack::Vector{SourceLocation}
end

#=============================================================================
CGVal: Unified value representation (analogous to Julia's jl_cgval_t)
=============================================================================#
Expand Down Expand Up @@ -322,6 +336,19 @@ mutable struct CGCtx
# `tuple_element_source` and other parent-walking queries can start
# from the right scope. `nothing` when no block has been entered yet.
current_block::Any

# Deferred codegen diagnostics. Rather than aborting on the first
# unsupported construct, codegen catches each `IRError` at the per-
# statement boundary (`emit_block!`), records it here with the offending
# statement's kernel-side inlining stack, and continues with a poison
# placeholder, so one compile surfaces all problems (cf. GPUCompiler's
# `InvalidIRError` accumulation). `report_errors!` raises the aggregate.
errors::Vector{CodegenError}
# SSA indices whose emission failed; their results are poison. A consumer
# that reads a poisoned value sets `touched_poison`, letting the boundary
# handler drop the cascading (derived) error and keep only the root cause.
poisoned::Set{Int}
touched_poison::Bool
end

function CGCtx(; cb::CodeBuilder, tt::TypeTable, sci::StructuredIRCode,
Expand Down Expand Up @@ -354,6 +381,9 @@ function CGCtx(; cb::CodeBuilder, tt::TypeTable, sci::StructuredIRCode,
nothing, # bounds_info — set by run_passes!
Dict{Value, Value}(), # assume_wrapped
nothing, # current_block — set by emit_block!
CodegenError[], # errors, accumulated by record_error!
Set{Int}(), # poisoned
false, # touched_poison
)
end

Expand Down
1 change: 1 addition & 0 deletions src/cuTile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ include("compiler/transform/licm.jl")
include("compiler/transform/dce.jl")
include("compiler/transform/pipeline.jl")
include("compiler/codegen/debug.jl")
include("compiler/codegen/errors.jl")
include("compiler/codegen/kernel.jl")
include("compiler/codegen/control_flow.jl")
include("compiler/codegen/statements.jl")
Expand Down
2 changes: 1 addition & 1 deletion test/codegen/integration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ end
@testset "binary op type mismatch errors in Julia" begin
# addi with mismatched types (Int32 + Int64) should fail if the
# result is used. Dead code is removed by DCE before codegen.
@test_throws ct.IRError code_tiled(Tuple{ct.TileArray{Float32,1,spec}}) do a
@test_throws ct.CodegenErrors code_tiled(Tuple{ct.TileArray{Float32,1,spec}}) do a
pid = ct.bid(1) # Int32
# Force type mismatch by calling addi with different types
result = ct.Intrinsics.addi(pid, Int64(1))
Expand Down
8 changes: 4 additions & 4 deletions test/codegen/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,7 @@ spec4d = ct.ArraySpec{4}(16, true)
end

@testset "vec-vec throws error" begin
@test_throws cuTile.IRError begin
@test_throws cuTile.CodegenErrors begin
code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,1,spec1d}}) do a, b
bidx = ct.bid(1)
tile_a = ct.load(a, bidx, (16,))
Expand Down Expand Up @@ -1182,7 +1182,7 @@ end

# pack/unpack require v13.3 — older bytecode rejects with a clear error.
# (`literal` since the `+` in the message is a regex metachar to FileCheck.)
@test @filecheck throws=ct.IRError begin
@test @filecheck throws=ct.CodegenErrors begin
@check literal=true "v13.3+"
code_tiled(Tuple{ct.TileArray{UInt8,1,spec1d}, ct.TileArray{UInt16,1,spec1d}};
bytecode_version=v"13.2") do a, b
Expand All @@ -1194,7 +1194,7 @@ end
end

# Rank-1 scaled: one UInt8 (8 bits) can't fill a UInt16; caught by unpack.
@test @filecheck throws=ct.IRError begin
@test @filecheck throws=ct.CodegenErrors begin
@check "do not evenly divide"
code_tiled(Tuple{ct.TileArray{UInt8,1,spec1d}, ct.TileArray{UInt16,1,spec1d}}) do a, b
pid = ct.bid(1)
Expand All @@ -1204,7 +1204,7 @@ end
end

# reshape-widen: leading dim must equal the ratio (2); 1 fails the final reshape.
@test @filecheck throws=ct.IRError begin
@test @filecheck throws=ct.CodegenErrors begin
@check "same number of elements"
code_tiled(Tuple{ct.TileArray{UInt8,2,spec2d}, ct.TileArray{UInt16,2,spec2d}}) do a, b
pid = ct.bid(1)
Expand Down
2 changes: 1 addition & 1 deletion test/device/tile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1106,7 +1106,7 @@ end
end

@testset "batched mat-vec (3D × 1D) errors" begin
@test_throws cuTile.IRError begin
@test_throws cuTile.CodegenErrors begin
ct.code_tiled(Tuple{ct.TileArray{Float32,3,ct.ArraySpec{3}(16,true)},
ct.TileArray{Float32,1,ct.ArraySpec{1}(16,true)}}) do a, b
bidx = ct.bid(1)
Expand Down