Skip to content

Commit 9594680

Browse files
committed
promote u8 to u32 if needed
1 parent 3f7cb64 commit 9594680

3 files changed

Lines changed: 109 additions & 1 deletion

File tree

crates/rustc_codegen_spirv/src/linker/mod.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,11 @@ pub fn link(
492492
simple_passes::remove_non_uniform_decorations(sess, &mut output)?;
493493
}
494494

495+
{
496+
let _timer = sess.timer("link_promote_int8_to_int32");
497+
simple_passes::promote_int8_to_int32(&mut output);
498+
}
499+
495500
// NOTE(eddyb) SPIR-T pipeline is entirely limited to this block.
496501
{
497502
let (spv_words, module_or_err, lower_from_spv_timer) =

crates/rustc_codegen_spirv/src/linker/simple_passes.rs

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use super::{get_name, get_names};
2-
use rspirv::dr::{Block, Function, Module};
2+
use rspirv::dr::{Block, Function, Module, Operand};
33
use rspirv::spirv::{Decoration, ExecutionModel, Op, Word};
44
use rustc_codegen_spirv_types::Capability;
55
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
@@ -365,3 +365,79 @@ pub fn remove_non_uniform_decorations(_sess: &Session, module: &mut Module) -> s
365365
}
366366
Ok(())
367367
}
368+
369+
/// When `OpCapability Int8` is not declared, promote all implicit `i8`/`u8` types to `i32`/`u32`.
370+
pub fn promote_int8_to_int32(module: &mut Module) {
371+
let has_int8 = module.capabilities.iter().any(|inst| {
372+
inst.class.opcode == Op::Capability
373+
&& inst.operands[0].unwrap_capability() == Capability::Int8
374+
});
375+
if has_int8 {
376+
return;
377+
}
378+
379+
let narrow_types: FxHashMap<Word, u32> = module
380+
.types_global_values
381+
.iter()
382+
.filter_map(|inst| {
383+
if inst.class.opcode == Op::TypeInt && inst.operands[0].unwrap_literal_bit32() == 8 {
384+
let signedness = inst.operands[1].unwrap_literal_bit32();
385+
Some((inst.result_id?, signedness))
386+
} else {
387+
None
388+
}
389+
})
390+
.collect();
391+
392+
if narrow_types.is_empty() {
393+
return;
394+
}
395+
396+
// skip any 8-bit type that is used as the element type of an OpTypePointer.
397+
// such types are explicit interface/storage types chosen by the user
398+
let pointer_element_types: FxHashSet<Word> = module
399+
.types_global_values
400+
.iter()
401+
.filter_map(|inst| {
402+
if inst.class.opcode == Op::TypePointer {
403+
// operands: [StorageClass, element_type_id]
404+
Some(inst.operands[1].unwrap_id_ref())
405+
} else {
406+
None
407+
}
408+
})
409+
.collect();
410+
411+
let narrow_types: FxHashMap<Word, u32> = narrow_types
412+
.into_iter()
413+
.filter(|(id, _)| !pointer_element_types.contains(id))
414+
.collect();
415+
416+
if narrow_types.is_empty() {
417+
return;
418+
}
419+
420+
for inst in &mut module.types_global_values {
421+
// widen each 8-bit OpTypeInt to 32 bits
422+
if inst.class.opcode == Op::TypeInt
423+
&& let Some(id) = inst.result_id
424+
&& narrow_types.contains_key(&id)
425+
{
426+
inst.operands[0] = Operand::LiteralBit32(32);
427+
}
428+
429+
// fix OpConstant values: sign-extend signed 8-bit constants to 32 bits.
430+
if inst.class.opcode == Op::Constant
431+
&& let Some(ty) = inst.result_type
432+
&& let Some(&signedness) = narrow_types.get(&ty)
433+
&& let Operand::LiteralBit32(ref mut val) = inst.operands[0]
434+
{
435+
let narrow = *val as u8;
436+
*val = if signedness != 0 {
437+
(narrow as i8 as i32) as u32
438+
} else {
439+
narrow as u32
440+
};
441+
}
442+
}
443+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// build-pass
2+
//PartialOrd on CustomPosition(u32) internally returns Option<Ordering>,
3+
//where Ordering is represented as i8 in Rust's layout.
4+
//This caused rust-gpu to emit OpTypeInt 8 declarations requiring OpCapability Int8
5+
#![no_std]
6+
7+
use spirv_std::{glam::Vec4, spirv};
8+
9+
pub struct ShaderInputs {
10+
pub x: CustomPosition,
11+
pub y: CustomPosition,
12+
}
13+
14+
#[derive(Debug, Clone, Copy, Ord, PartialOrd, Eq, PartialEq)]
15+
pub struct CustomPosition(u32);
16+
17+
#[spirv(vertex)]
18+
pub fn test_vs(
19+
#[spirv(push_constant)] inputs: &ShaderInputs,
20+
#[spirv(position)] out_pos: &mut Vec4,
21+
) {
22+
let mut result: f32 = 0.;
23+
if inputs.x < inputs.y {
24+
result = 1.0;
25+
}
26+
*out_pos = Vec4::new(inputs.x.0 as f32, inputs.y.0 as f32, result as f32, 1.0);
27+
}

0 commit comments

Comments
 (0)