spirv: Refactor to use glsl_type to pick ALU ops

Reviewed-by: Ian Romanick <ian.d.romanick@intel.com>
Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34105>
This commit is contained in:
Caio Oliveira 2025-04-11 10:36:12 -07:00 committed by Marge Bot
parent bba607ac2b
commit fb6ae2eac1
4 changed files with 22 additions and 22 deletions

View file

@ -2677,8 +2677,10 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
default: {
bool swap;
nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(val->type->type);
nir_alu_type src_alu_type = dst_alu_type;
const glsl_type *dst_type = val->type->type;
const glsl_type *src_type = dst_type;
unsigned num_components = glsl_get_vector_elements(val->type->type);
vtn_assert(count <= 7);
@ -2688,8 +2690,7 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
case SpvOpFConvert:
case SpvOpUConvert:
/* We have a different source type in a conversion. */
src_alu_type =
nir_get_nir_type_for_glsl_type(vtn_get_value_type(b, w[4])->type);
src_type = vtn_get_value_type(b, w[4])->type;
break;
default:
break;
@ -2697,8 +2698,7 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
bool exact;
nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &exact,
nir_alu_type_get_type_size(src_alu_type),
nir_alu_type_get_type_size(dst_alu_type));
src_type, dst_type);
/* No SPIR-V opcodes handled through this path should set exact.
* Since it is ignored, assert on it.

View file

@ -280,8 +280,12 @@ vtn_convert_op_dst_type(SpvOp opcode)
nir_op
vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
SpvOp opcode, bool *swap, bool *exact,
unsigned src_bit_size, unsigned dst_bit_size)
const glsl_type *src_type,
const glsl_type *dst_type)
{
const unsigned src_bit_size = glsl_get_bit_size(src_type);
const unsigned dst_bit_size = glsl_get_bit_size(dst_type);
/* Indicates that the first two arguments should be swapped. This is
* used for implementing greater-than and less-than-or-equal.
*/
@ -890,11 +894,9 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
case SpvOpFUnordGreaterThanEqual: {
bool swap;
bool unused_exact;
unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
unsigned dst_bit_size = glsl_get_bit_size(dest_type);
nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
&unused_exact,
src_bit_size, dst_bit_size);
vtn_src[0]->type, dest_type);
if (swap) {
nir_def *tmp = src[0];
@ -969,10 +971,8 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
case SpvOpShiftRightLogical: {
bool swap;
bool exact;
unsigned src0_bit_size = glsl_get_bit_size(vtn_src[0]->type);
unsigned dst_bit_size = glsl_get_bit_size(dest_type);
nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &exact,
src0_bit_size, dst_bit_size);
vtn_src[0]->type, dest_type);
assert(!exact);
@ -1029,11 +1029,9 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
default: {
bool swap;
bool exact;
unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
unsigned dst_bit_size = glsl_get_bit_size(dest_type);
nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
&exact,
src_bit_size, dst_bit_size);
vtn_src[0]->type, dest_type);
if (swap) {
nir_def *tmp = src[0];

View file

@ -223,12 +223,10 @@ vtn_handle_cooperative_alu(struct vtn_builder *b, struct vtn_value *dest_val,
struct vtn_type *dst_type = vtn_get_type(b, w[1]);
nir_deref_instr *src = vtn_get_cmat_deref(b, w[3]);
unsigned src_bit_size = glsl_get_bit_size(glsl_get_cmat_element(src->type));
unsigned dst_bit_size = glsl_get_bit_size(glsl_get_cmat_element(dst_type->type));
bool ignored = false;
nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &ignored, &ignored,
src_bit_size, dst_bit_size);
glsl_get_cmat_element(src->type),
glsl_get_cmat_element(dst_type->type));
nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_unary");
nir_cmat_unary_op(&b->nb, &dst->def, &src->def,
@ -247,12 +245,15 @@ vtn_handle_cooperative_alu(struct vtn_builder *b, struct vtn_value *dest_val,
case SpvOpSDiv:
case SpvOpUDiv: {
bool ignored = false;
nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &ignored, &ignored, 0, 0);
struct vtn_type *dst_type = vtn_get_type(b, w[1]);
nir_deref_instr *mat_a = vtn_get_cmat_deref(b, w[3]);
nir_deref_instr *mat_b = vtn_get_cmat_deref(b, w[4]);
nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &ignored, &ignored,
glsl_get_cmat_element(mat_a->type),
glsl_get_cmat_element(dst_type->type));
nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_binary");
nir_cmat_binary_op(&b->nb, &dst->def, &mat_a->def, &mat_b->def,
.alu_op = op);

View file

@ -959,7 +959,8 @@ nir_alu_type vtn_convert_op_dst_type(SpvOp opcode);
nir_op vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
SpvOp opcode, bool *swap, bool *exact,
unsigned src_bit_size, unsigned dst_bit_size);
const glsl_type *src_type,
const glsl_type *dst_type);
void vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
const uint32_t *w, unsigned count);