From bba607ac2b6fd14ce33bbff44fc1684e41541a7a Mon Sep 17 00:00:00 2001 From: Caio Oliveira Date: Fri, 14 Mar 2025 11:09:00 -0700 Subject: [PATCH] spirv: Move Convert opcodes handling to its own function Take the opportunity to add a comment about why the bit_size comes from the NIR def and not the original type. Reviewed-by: Ian Romanick Reviewed-by: Rohan Garg Reviewed-by: Georg Lehmann Part-of: --- src/compiler/spirv/vtn_alu.c | 81 +++++++++++++++++++++--------------- 1 file changed, 48 insertions(+), 33 deletions(-) diff --git a/src/compiler/spirv/vtn_alu.c b/src/compiler/spirv/vtn_alu.c index caa2ad9a1ac..e3c10f51dc4 100644 --- a/src/compiler/spirv/vtn_alu.c +++ b/src/compiler/spirv/vtn_alu.c @@ -660,6 +660,52 @@ vtn_handle_deriv(struct vtn_builder *b, SpvOp opcode, nir_def *src) } } +static nir_def * +vtn_handle_convert(struct vtn_builder *b, SpvOp opcode, + struct vtn_value *dest_val, + const struct glsl_type *glsl_dest_type, + nir_def *src) +{ + /* Use bit_size from NIR source instead of from the original src type, + * to account for mediump_16bit. See vtn_handle_alu() for details. + */ + unsigned src_bit_size = src->bit_size; + unsigned dst_bit_size = glsl_get_bit_size(glsl_dest_type); + nir_alu_type src_type = vtn_convert_op_src_type(opcode) | src_bit_size; + nir_alu_type dst_type = vtn_convert_op_dst_type(opcode) | dst_bit_size; + + struct conversion_opts opts = { + .rounding_mode = nir_rounding_mode_undef, + .saturate = false, + }; + vtn_foreach_decoration(b, dest_val, handle_conversion_opts, &opts); + + if (opcode == SpvOpSatConvertSToU || opcode == SpvOpSatConvertUToS) + opts.saturate = true; + + nir_def *result; + + if (b->shader->info.stage == MESA_SHADER_KERNEL) { + if (opts.rounding_mode == nir_rounding_mode_undef && !opts.saturate) { + result = nir_type_convert(&b->nb, src, src_type, dst_type, + nir_rounding_mode_undef); + } else { + result = nir_convert_alu_types(&b->nb, dst_bit_size, src, + src_type, dst_type, + opts.rounding_mode, opts.saturate); + } + } else { + vtn_fail_if(opts.rounding_mode != nir_rounding_mode_undef && + dst_type != nir_type_float16, + "Rounding modes are only allowed on conversions to " + "16-bit float types"); + result = nir_type_convert(&b->nb, src, src_type, dst_type, + opts.rounding_mode); + } + + return result; +} + void vtn_handle_alu(struct vtn_builder *b, SpvOp opcode, const uint32_t *w, unsigned count) @@ -911,40 +957,9 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode, case SpvOpSConvert: case SpvOpFConvert: case SpvOpSatConvertSToU: - case SpvOpSatConvertUToS: { - unsigned src_bit_size = src[0]->bit_size; - unsigned dst_bit_size = glsl_get_bit_size(dest_type); - nir_alu_type src_type = vtn_convert_op_src_type(opcode) | src_bit_size; - nir_alu_type dst_type = vtn_convert_op_dst_type(opcode) | dst_bit_size; - - struct conversion_opts opts = { - .rounding_mode = nir_rounding_mode_undef, - .saturate = false, - }; - vtn_foreach_decoration(b, dest_val, handle_conversion_opts, &opts); - - if (opcode == SpvOpSatConvertSToU || opcode == SpvOpSatConvertUToS) - opts.saturate = true; - - if (b->shader->info.stage == MESA_SHADER_KERNEL) { - if (opts.rounding_mode == nir_rounding_mode_undef && !opts.saturate) { - dest->def = nir_type_convert(&b->nb, src[0], src_type, dst_type, - nir_rounding_mode_undef); - } else { - dest->def = nir_convert_alu_types(&b->nb, dst_bit_size, src[0], - src_type, dst_type, - opts.rounding_mode, opts.saturate); - } - } else { - vtn_fail_if(opts.rounding_mode != nir_rounding_mode_undef && - dst_type != nir_type_float16, - "Rounding modes are only allowed on conversions to " - "16-bit float types"); - dest->def = nir_type_convert(&b->nb, src[0], src_type, dst_type, - opts.rounding_mode); - } + case SpvOpSatConvertUToS: + dest->def = vtn_handle_convert(b, opcode, dest_val, dest_type, src[0]); break; - } case SpvOpBitFieldInsert: case SpvOpBitFieldSExtract: