diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index f52b52ea799..d4e2537124d 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -1886,10 +1886,26 @@ vtn_handle_type(struct vtn_builder *b, SpvOp opcode, case SpvOpTypeFloat: { int bit_size = w[2]; val->type->base_type = vtn_base_type_scalar; - vtn_fail_if(bit_size != 16 && bit_size != 32 && bit_size != 64, - "Invalid float bit size: %u", bit_size); - val->type->type = glsl_floatN_t_type(bit_size); val->type->length = 1; + + int32_t encoding = count > 3 ? w[3] : -1; + switch (encoding) { + case -1: + /* No encoding specified, it is a regular FP. */ + vtn_fail_if(bit_size != 16 && bit_size != 32 && bit_size != 64, + "Invalid float bit size: %u", bit_size); + val->type->type = glsl_floatN_t_type(bit_size); + break; + + case SpvFPEncodingBFloat16KHR: + vtn_fail_if(bit_size != 16, + "Invalid Bfloat16 bit size: %u", bit_size); + val->type->type = glsl_bfloatN_t_type(bit_size); + break; + + default: + vtn_fail("Unsupported OpTypeFloat encoding: %d", encoding); + } break; }