Skip to content
6 changes: 6 additions & 0 deletions crates/rustc_codegen_spirv/src/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ pub enum SpirvAttribute {
Builtin(BuiltIn),
DescriptorSet(u32),
Binding(u32),
Location(u32),
Flat,
PerPrimitiveExt,
Invariant,
Expand Down Expand Up @@ -130,6 +131,7 @@ pub struct AggregatedSpirvAttributes {
pub builtin: Option<Spanned<BuiltIn>>,
pub descriptor_set: Option<Spanned<u32>>,
pub binding: Option<Spanned<u32>>,
pub location: Option<Spanned<u32>>,
pub flat: Option<Spanned<()>>,
pub invariant: Option<Spanned<()>>,
pub per_primitive_ext: Option<Spanned<()>>,
Expand Down Expand Up @@ -216,6 +218,7 @@ impl AggregatedSpirvAttributes {
"#[spirv(descriptor_set)]",
),
Binding(value) => try_insert(&mut self.binding, value, span, "#[spirv(binding)]"),
Location(value) => try_insert(&mut self.location, value, span, "#[spirv(location)]"),
Flat => try_insert(&mut self.flat, (), span, "#[spirv(flat)]"),
Invariant => try_insert(&mut self.invariant, (), span, "#[spirv(invariant)]"),
PerPrimitiveExt => try_insert(
Expand Down Expand Up @@ -323,6 +326,7 @@ impl CheckSpirvAttrVisitor<'_> {
| SpirvAttribute::Builtin(_)
| SpirvAttribute::DescriptorSet(_)
| SpirvAttribute::Binding(_)
| SpirvAttribute::Location(_)
| SpirvAttribute::Flat
| SpirvAttribute::Invariant
| SpirvAttribute::PerPrimitiveExt
Expand Down Expand Up @@ -602,6 +606,8 @@ fn parse_spirv_attr<'a>(
SpirvAttribute::DescriptorSet(parse_attr_int_value(arg)?)
} else if arg.has_name(sym.binding) {
SpirvAttribute::Binding(parse_attr_int_value(arg)?)
} else if arg.has_name(sym.location) {
SpirvAttribute::Location(parse_attr_int_value(arg)?)
} else if arg.has_name(sym.input_attachment_index) {
SpirvAttribute::InputAttachmentIndex(parse_attr_int_value(arg)?)
} else if arg.has_name(sym.spec_constant) {
Expand Down
81 changes: 66 additions & 15 deletions crates/rustc_codegen_spirv/src/codegen_cx/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,41 @@ impl<'tcx> CodegenCx<'tcx> {
.name(var_id.or(spec_const_id).unwrap(), ident.to_string());
}

// location assignment
// Note(@firestar99): UniformConstant are things like `SampledImage`, `StorageImage`, `Sampler` and
// `Acceleration structure`. Almost always they are assigned a `descriptor_set` and binding, thus never end up
// here being assigned locations. I think this is one of those occasions where spirv allows us to assign
// locations, but the "client API" Vulkan doesn't describe any use-case for them, or at least none I'm aware of.
// A quick scour through the spec revealed that `VK_KHR_dynamic_rendering_local_read` may need this, and while
// we don't support it yet (I assume), I'll just keep it here in case it becomes useful in the future.
let has_location = matches!(
storage_class,
Ok(StorageClass::Input | StorageClass::Output | StorageClass::UniformConstant)
);
let mut assign_location = |var_id: Result<Word, &str>, explicit: Option<u32>| {
let location = decoration_locations
.entry(storage_class.unwrap())
.or_insert_with(|| 0);
if let Some(explicit) = explicit {
*location = explicit;
}
self.emit_global().decorate(
var_id.unwrap(),
Decoration::Location,
std::iter::once(Operand::LiteralBit32(*location)),
);
let spirv_type = self.lookup_type(value_spirv_type);
if let Some(location_size) = spirv_type.location_size(self) {
*location += location_size;
} else {
*location += 1;
self.tcx.dcx().span_err(
hir_param.ty_span,
"Type not supported in Input or Output declarations",
);
}
};

// Emit `OpDecorate`s based on attributes.
let mut decoration_supersedes_location = false;
if let Some(builtin) = attrs.builtin {
Expand Down Expand Up @@ -757,6 +792,35 @@ impl<'tcx> CodegenCx<'tcx> {
);
decoration_supersedes_location = true;
}
if let Some(location) = attrs.location {
if let Err(SpecConstant { .. }) = storage_class {
self.tcx.dcx().span_fatal(
location.span,
"`#[spirv(location = ...)]` cannot apply to `#[spirv(spec_constant)]`",
);
}
if attrs.descriptor_set.is_some() {
self.tcx.dcx().span_fatal(
location.span,
"`#[spirv(location = ...)]` cannot be combined with `#[spirv(descriptor_set = ...)]`",
);
}
if attrs.binding.is_some() {
self.tcx.dcx().span_fatal(
location.span,
"`#[spirv(location = ...)]` cannot be combined with `#[spirv(binding = ...)]`",
);
}
if !has_location {
self.tcx.dcx().span_fatal(
location.span,
"`#[spirv(location = ...)]` can only be used on Inputs (declared as plain values, eg. `Vec4`) \
or Outputs (declared as mut ref, eg. `&mut Vec4`)",
);
}
assign_location(var_id, Some(location.value));
decoration_supersedes_location = true;
}
if let Some(flat) = attrs.flat {
if let Err(SpecConstant { .. }) = storage_class {
self.tcx.dcx().span_fatal(
Expand Down Expand Up @@ -867,21 +931,8 @@ impl<'tcx> CodegenCx<'tcx> {
// individually.
// TODO: Is this right for UniformConstant? Do they share locations with
// input/outpus?
let has_location = !decoration_supersedes_location
&& matches!(
storage_class,
Ok(StorageClass::Input | StorageClass::Output | StorageClass::UniformConstant)
);
if has_location {
let location = decoration_locations
.entry(storage_class.unwrap())
.or_insert_with(|| 0);
self.emit_global().decorate(
var_id.unwrap(),
Decoration::Location,
std::iter::once(Operand::LiteralBit32(*location)),
);
*location += 1;
if !decoration_supersedes_location && has_location {
assign_location(var_id, None);
}

match storage_class {
Expand Down
55 changes: 48 additions & 7 deletions crates/rustc_codegen_spirv/src/spirv_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,42 @@ impl SpirvType<'_> {
id
}

/// Returns how many Input / Output `location`s this type occupies, or None if this type is not allowed to be sent.
///
/// See [Vulkan Spec 16.1.4. Location and Component Assignment](https://registry.khronos.org/vulkan/specs/latest/html/vkspec.html#interfaces-iointerfaces-locations)
#[allow(clippy::match_same_arms)]
pub fn location_size(&self, cx: &CodegenCx<'_>) -> Option<u32> {
let result = match *self {
// bools cannot be in an Input / Output interface
Self::Bool => return None,
Self::Integer(_, _) | Self::Float(_) => 1,
Self::Vector { .. } => 1,
Self::Adt { field_types, .. } => {
let mut locations = 0;
for f in field_types {
locations += cx.lookup_type(*f).location_size(cx)?;
}
locations
}
Self::Matrix { element, count } => cx
.lookup_type(element)
.location_size(cx)?
.checked_mul(count)
.expect("overflow"),
Self::Array { element, count } => {
let element = cx.lookup_type(element).location_size(cx)?;
let count = cx
.builder
.lookup_const_scalar(count)
.and_then(|c| u32::try_from(c).ok())
.expect("SpirvType::Array.count to be a u32 constant");
element.checked_mul(count).expect("overflow")
}
_ => return None,
};
Some(result)
}

pub fn sizeof(&self, cx: &CodegenCx<'_>) -> Option<Size> {
let result = match *self {
// Types that have a dynamic size, or no concept of size at all.
Expand All @@ -285,14 +321,19 @@ impl SpirvType<'_> {
Self::Integer(width, _) | Self::Float(width) => Size::from_bits(width),
Self::Adt { size, .. } => size?,
Self::Vector { size, .. } => size,
Self::Matrix { element, count } => cx.lookup_type(element).sizeof(cx)? * count as u64,
Self::Matrix { element, count } => cx
.lookup_type(element)
.sizeof(cx)?
.checked_mul(count as u64, cx)
.expect("overflow"),
Self::Array { element, count } => {
cx.lookup_type(element).sizeof(cx)?
* cx.builder
.lookup_const_scalar(count)
.unwrap()
.try_into()
.unwrap()
let element = cx.lookup_type(element).sizeof(cx)?;
let count = cx
.builder
.lookup_const_scalar(count)
.and_then(|c| u64::try_from(c).ok())
.expect("SpirvType::Array.count to be a u32 constant");
element.checked_mul(count, cx).expect("overflow")
}
Self::Pointer { .. } => cx.tcx.data_layout.pointer_size,
Self::Image { .. }
Expand Down
2 changes: 2 additions & 0 deletions crates/rustc_codegen_spirv/src/symbols.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pub struct Symbols {

pub descriptor_set: Symbol,
pub binding: Symbol,
pub location: Symbol,
pub input_attachment_index: Symbol,

pub spec_constant: Symbol,
Expand Down Expand Up @@ -420,6 +421,7 @@ impl Symbols {

descriptor_set: Symbol::intern("descriptor_set"),
binding: Symbol::intern("binding"),
location: Symbol::intern("location"),
input_attachment_index: Symbol::intern("input_attachment_index"),

spec_constant: Symbol::intern("spec_constant"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,11 @@ LL | #![feature(ptr_internals)]
= note: using it is strongly discouraged
= note: `#[warn(internal_features)]` on by default

error: pointer has non-null integer address
|
note: used from within `allocate_const_scalar::main`
--> $DIR/allocate_const_scalar.rs:16:5
|
LL | *output = POINTER;
| ^^^^^^^^^^^^^^^^^
note: called by `main`
--> $DIR/allocate_const_scalar.rs:15:8
error: Type not supported in Input or Output declarations
--> $DIR/allocate_const_scalar.rs:15:21
|
LL | pub fn main(output: &mut Unique<()>) {
| ^^^^
| ^^^^^^^^^^^^^^^

error: aborting due to 1 previous error; 1 warning emitted

Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ use spirv_std::{Image, spirv};

#[spirv(vertex)]
pub fn main(
#[spirv(uniform)] error: &Image!(2D, type=f32),
#[spirv(uniform_constant)] warning: &Image!(2D, type=f32),
#[spirv(descriptor_set = 0, binding = 0, uniform)] error: &Image!(2D, type=f32),
#[spirv(descriptor_set = 0, binding = 1, uniform_constant)] warning: &Image!(2D, type=f32),
) {
}

// https://github.com/EmbarkStudios/rust-gpu/issues/585
#[spirv(vertex)]
pub fn issue_585(invalid: Image!(2D, type=f32)) {}
pub fn issue_585(#[spirv(descriptor_set = 0, binding = 0)] invalid: Image!(2D, type=f32)) {}
22 changes: 11 additions & 11 deletions tests/compiletests/ui/spirv-attr/bad-deduce-storage-class.stderr
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
error: storage class mismatch
--> $DIR/bad-deduce-storage-class.rs:8:5
|
LL | #[spirv(uniform)] error: &Image!(2D, type=f32),
| ^^^^^^^^-------^^^^^^^^^^---------------------
| | |
| | `UniformConstant` deduced from type
| `Uniform` specified in attribute
LL | #[spirv(descriptor_set = 0, binding = 0, uniform)] error: &Image!(2D, type=f32),
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^-------^^^^^^^^^^---------------------
| | |
| | `UniformConstant` deduced from type
| `Uniform` specified in attribute
|
= help: remove storage class attribute to use `UniformConstant` as storage class

warning: redundant storage class attribute, storage class is deduced from type
--> $DIR/bad-deduce-storage-class.rs:9:13
--> $DIR/bad-deduce-storage-class.rs:9:46
|
LL | #[spirv(uniform_constant)] warning: &Image!(2D, type=f32),
| ^^^^^^^^^^^^^^^^
LL | #[spirv(descriptor_set = 0, binding = 1, uniform_constant)] warning: &Image!(2D, type=f32),
| ^^^^^^^^^^^^^^^^

error: entry parameter type must be by-reference: `&spirv_std::image::Image<f32, 1, 2, 0, 0, 0, 0, 4>`
--> $DIR/bad-deduce-storage-class.rs:15:27
--> $DIR/bad-deduce-storage-class.rs:15:69
|
LL | pub fn issue_585(invalid: Image!(2D, type=f32)) {}
| ^^^^^^^^^^^^^^^^^^^^
LL | pub fn issue_585(#[spirv(descriptor_set = 0, binding = 0)] invalid: Image!(2D, type=f32)) {}
| ^^^^^^^^^^^^^^^^^^^^
|
= note: this error originates in the macro `Image` (in Nightly builds, run with -Z macro-backtrace for more info)

Expand Down
14 changes: 13 additions & 1 deletion tests/compiletests/ui/spirv-attr/bool-inputs-err.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,24 @@ error: entry-point parameter cannot contain `bool`s
LL | input: bool,
| ^^^^

error: Type not supported in Input or Output declarations
--> $DIR/bool-inputs-err.rs:13:12
|
LL | input: bool,
| ^^^^

error: entry-point parameter cannot contain `bool`s
--> $DIR/bool-inputs-err.rs:14:13
|
LL | output: &mut bool,
| ^^^^^^^^^

error: Type not supported in Input or Output declarations
--> $DIR/bool-inputs-err.rs:14:13
|
LL | output: &mut bool,
| ^^^^^^^^^

error: entry-point parameter cannot contain `bool`s
--> $DIR/bool-inputs-err.rs:15:35
|
Expand All @@ -22,5 +34,5 @@ error: entry-point parameter cannot contain `bool`s
LL | #[spirv(uniform)] uniform: &Boolthing,
| ^^^^^^^^^^

error: aborting due to 4 previous errors
error: aborting due to 6 previous errors

31 changes: 31 additions & 0 deletions tests/compiletests/ui/spirv-attr/location_assignment.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// build-pass
// compile-flags: -C llvm-args=--disassemble
// normalize-stderr-test "OpSource .*\n" -> ""
// normalize-stderr-test "OpLine .*\n" -> ""
// normalize-stderr-test "%\d+ = OpString .*\n" -> ""
// normalize-stderr-test "; .*\n" -> ""
// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> ""
// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple"
// ignore-spv1.0
// ignore-spv1.1
// ignore-spv1.2
// ignore-spv1.3
// ignore-vulkan1.0
// ignore-vulkan1.1

use spirv_std::glam::*;
use spirv_std::{Image, spirv};

#[derive(Copy, Clone, Default)]
pub struct LargerThanVec4 {
a: Vec4,
b: Vec2,
}

#[spirv(vertex)]
pub fn main(out1: &mut LargerThanVec4, out2: &mut Vec2, out3: &mut Mat4, out4: &mut f32) {
*out1 = Default::default();
*out2 = Default::default();
*out3 = Default::default();
*out4 = Default::default();
}
Loading