spirv: Move Convert opcodes handling to its own function

Take the opportunity to add a comment about why the bit_size
comes from the NIR def and not the original type.

Reviewed-by: Ian Romanick <ian.d.romanick@intel.com>
Reviewed-by: Rohan Garg <rohan.garg@intel.com>
Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34105>
This commit is contained in:
Caio Oliveira 2025-03-14 11:09:00 -07:00 committed by Marge Bot
parent d4381c0908
commit bba607ac2b

View file

@ -660,6 +660,52 @@ vtn_handle_deriv(struct vtn_builder *b, SpvOp opcode, nir_def *src)
}
}
static nir_def *
vtn_handle_convert(struct vtn_builder *b, SpvOp opcode,
struct vtn_value *dest_val,
const struct glsl_type *glsl_dest_type,
nir_def *src)
{
/* Use bit_size from NIR source instead of from the original src type,
* to account for mediump_16bit. See vtn_handle_alu() for details.
*/
unsigned src_bit_size = src->bit_size;
unsigned dst_bit_size = glsl_get_bit_size(glsl_dest_type);
nir_alu_type src_type = vtn_convert_op_src_type(opcode) | src_bit_size;
nir_alu_type dst_type = vtn_convert_op_dst_type(opcode) | dst_bit_size;
struct conversion_opts opts = {
.rounding_mode = nir_rounding_mode_undef,
.saturate = false,
};
vtn_foreach_decoration(b, dest_val, handle_conversion_opts, &opts);
if (opcode == SpvOpSatConvertSToU || opcode == SpvOpSatConvertUToS)
opts.saturate = true;
nir_def *result;
if (b->shader->info.stage == MESA_SHADER_KERNEL) {
if (opts.rounding_mode == nir_rounding_mode_undef && !opts.saturate) {
result = nir_type_convert(&b->nb, src, src_type, dst_type,
nir_rounding_mode_undef);
} else {
result = nir_convert_alu_types(&b->nb, dst_bit_size, src,
src_type, dst_type,
opts.rounding_mode, opts.saturate);
}
} else {
vtn_fail_if(opts.rounding_mode != nir_rounding_mode_undef &&
dst_type != nir_type_float16,
"Rounding modes are only allowed on conversions to "
"16-bit float types");
result = nir_type_convert(&b->nb, src, src_type, dst_type,
opts.rounding_mode);
}
return result;
}
void
vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
const uint32_t *w, unsigned count)
@ -911,40 +957,9 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
case SpvOpSConvert:
case SpvOpFConvert:
case SpvOpSatConvertSToU:
case SpvOpSatConvertUToS: {
unsigned src_bit_size = src[0]->bit_size;
unsigned dst_bit_size = glsl_get_bit_size(dest_type);
nir_alu_type src_type = vtn_convert_op_src_type(opcode) | src_bit_size;
nir_alu_type dst_type = vtn_convert_op_dst_type(opcode) | dst_bit_size;
struct conversion_opts opts = {
.rounding_mode = nir_rounding_mode_undef,
.saturate = false,
};
vtn_foreach_decoration(b, dest_val, handle_conversion_opts, &opts);
if (opcode == SpvOpSatConvertSToU || opcode == SpvOpSatConvertUToS)
opts.saturate = true;
if (b->shader->info.stage == MESA_SHADER_KERNEL) {
if (opts.rounding_mode == nir_rounding_mode_undef && !opts.saturate) {
dest->def = nir_type_convert(&b->nb, src[0], src_type, dst_type,
nir_rounding_mode_undef);
} else {
dest->def = nir_convert_alu_types(&b->nb, dst_bit_size, src[0],
src_type, dst_type,
opts.rounding_mode, opts.saturate);
}
} else {
vtn_fail_if(opts.rounding_mode != nir_rounding_mode_undef &&
dst_type != nir_type_float16,
"Rounding modes are only allowed on conversions to "
"16-bit float types");
dest->def = nir_type_convert(&b->nb, src[0], src_type, dst_type,
opts.rounding_mode);
}
case SpvOpSatConvertUToS:
dest->def = vtn_handle_convert(b, opcode, dest_val, dest_type, src[0]);
break;
}
case SpvOpBitFieldInsert:
case SpvOpBitFieldSExtract: