diff --git a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c index dc864ede740..27a4270e2e4 100644 --- a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c +++ b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c @@ -3646,18 +3646,6 @@ nir_to_spirv(struct nir_shader *s, const struct zink_so_info *so_info, uint32_t spirv_builder_emit_cap(&ctx.builder, SpvCapabilityImageQuery); } - if (s->info.bit_sizes_int & 8) - spirv_builder_emit_cap(&ctx.builder, SpvCapabilityInt8); - if (s->info.bit_sizes_int & 16) - spirv_builder_emit_cap(&ctx.builder, SpvCapabilityInt16); - if (s->info.bit_sizes_int & 64) - spirv_builder_emit_cap(&ctx.builder, SpvCapabilityInt64); - - if (s->info.bit_sizes_float & 16) - spirv_builder_emit_cap(&ctx.builder, SpvCapabilityFloat16); - if (s->info.bit_sizes_float & 64) - spirv_builder_emit_cap(&ctx.builder, SpvCapabilityFloat64); - ctx.stage = s->info.stage; ctx.so_info = so_info; ctx.GLSL_std_450 = spirv_builder_import(&ctx.builder, "GLSL.std.450"); diff --git a/src/gallium/drivers/zink/nir_to_spirv/spirv_builder.c b/src/gallium/drivers/zink/nir_to_spirv/spirv_builder.c index f859a3372e0..d93d47d11e4 100644 --- a/src/gallium/drivers/zink/nir_to_spirv/spirv_builder.c +++ b/src/gallium/drivers/zink/nir_to_spirv/spirv_builder.c @@ -1204,6 +1204,12 @@ SpvId spirv_builder_type_uint(struct spirv_builder *b, unsigned width) { uint32_t args[] = { width, 0 }; + if (width == 8) + spirv_builder_emit_cap(b, SpvCapabilityInt8); + else if (width == 16) + spirv_builder_emit_cap(b, SpvCapabilityInt16); + else if (width == 64) + spirv_builder_emit_cap(b, SpvCapabilityInt64); return get_type_def(b, SpvOpTypeInt, args, ARRAY_SIZE(args)); } @@ -1211,6 +1217,10 @@ SpvId spirv_builder_type_float(struct spirv_builder *b, unsigned width) { uint32_t args[] = { width }; + if (width == 16) + spirv_builder_emit_cap(b, SpvCapabilityFloat16); + else if (width == 64) + spirv_builder_emit_cap(b, SpvCapabilityFloat64); return get_type_def(b, SpvOpTypeFloat, args, ARRAY_SIZE(args)); } @@ -1430,6 +1440,12 @@ SpvId spirv_builder_const_uint(struct spirv_builder *b, int width, uint64_t val) { assert(width >= 8); + if (width == 8) + spirv_builder_emit_cap(b, SpvCapabilityInt8); + else if (width == 16) + spirv_builder_emit_cap(b, SpvCapabilityInt16); + else if (width == 64) + spirv_builder_emit_cap(b, SpvCapabilityInt64); SpvId type = spirv_builder_type_uint(b, width); if (width <= 32) return emit_constant_32(b, type, val); @@ -1449,12 +1465,15 @@ spirv_builder_const_float(struct spirv_builder *b, int width, double val) { assert(width >= 16); SpvId type = spirv_builder_type_float(b, width); - if (width == 16) + if (width == 16) { + spirv_builder_emit_cap(b, SpvCapabilityFloat16); return emit_constant_32(b, type, _mesa_float_to_half(val)); - else if (width == 32) + } else if (width == 32) return emit_constant_32(b, type, u_bitcast_f2u(val)); - else if (width == 64) + else if (width == 64) { + spirv_builder_emit_cap(b, SpvCapabilityFloat64); return emit_constant_64(b, type, u_bitcast_d2u(val)); + } unreachable("unhandled float-width"); }