From fb6ae2eac1fd883d69cca3cdb26c50a7272496dc Mon Sep 17 00:00:00 2001 From: Caio Oliveira Date: Fri, 11 Apr 2025 10:36:12 -0700 Subject: [PATCH] spirv: Refactor to use glsl_type to pick ALU ops Reviewed-by: Ian Romanick Reviewed-by: Georg Lehmann Part-of: --- src/compiler/spirv/spirv_to_nir.c | 12 ++++++------ src/compiler/spirv/vtn_alu.c | 18 ++++++++---------- src/compiler/spirv/vtn_cmat.c | 11 ++++++----- src/compiler/spirv/vtn_private.h | 3 ++- 4 files changed, 22 insertions(+), 22 deletions(-) diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index 5dad3060d13..f52b52ea799 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -2677,8 +2677,10 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, default: { bool swap; - nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(val->type->type); - nir_alu_type src_alu_type = dst_alu_type; + + const glsl_type *dst_type = val->type->type; + const glsl_type *src_type = dst_type; + unsigned num_components = glsl_get_vector_elements(val->type->type); vtn_assert(count <= 7); @@ -2688,8 +2690,7 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, case SpvOpFConvert: case SpvOpUConvert: /* We have a different source type in a conversion. */ - src_alu_type = - nir_get_nir_type_for_glsl_type(vtn_get_value_type(b, w[4])->type); + src_type = vtn_get_value_type(b, w[4])->type; break; default: break; @@ -2697,8 +2698,7 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, bool exact; nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &exact, - nir_alu_type_get_type_size(src_alu_type), - nir_alu_type_get_type_size(dst_alu_type)); + src_type, dst_type); /* No SPIR-V opcodes handled through this path should set exact. * Since it is ignored, assert on it. diff --git a/src/compiler/spirv/vtn_alu.c b/src/compiler/spirv/vtn_alu.c index e3c10f51dc4..f0b459a53d0 100644 --- a/src/compiler/spirv/vtn_alu.c +++ b/src/compiler/spirv/vtn_alu.c @@ -280,8 +280,12 @@ vtn_convert_op_dst_type(SpvOp opcode) nir_op vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b, SpvOp opcode, bool *swap, bool *exact, - unsigned src_bit_size, unsigned dst_bit_size) + const glsl_type *src_type, + const glsl_type *dst_type) { + const unsigned src_bit_size = glsl_get_bit_size(src_type); + const unsigned dst_bit_size = glsl_get_bit_size(dst_type); + /* Indicates that the first two arguments should be swapped. This is * used for implementing greater-than and less-than-or-equal. */ @@ -890,11 +894,9 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode, case SpvOpFUnordGreaterThanEqual: { bool swap; bool unused_exact; - unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type); - unsigned dst_bit_size = glsl_get_bit_size(dest_type); nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &unused_exact, - src_bit_size, dst_bit_size); + vtn_src[0]->type, dest_type); if (swap) { nir_def *tmp = src[0]; @@ -969,10 +971,8 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode, case SpvOpShiftRightLogical: { bool swap; bool exact; - unsigned src0_bit_size = glsl_get_bit_size(vtn_src[0]->type); - unsigned dst_bit_size = glsl_get_bit_size(dest_type); nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &exact, - src0_bit_size, dst_bit_size); + vtn_src[0]->type, dest_type); assert(!exact); @@ -1029,11 +1029,9 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode, default: { bool swap; bool exact; - unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type); - unsigned dst_bit_size = glsl_get_bit_size(dest_type); nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &exact, - src_bit_size, dst_bit_size); + vtn_src[0]->type, dest_type); if (swap) { nir_def *tmp = src[0]; diff --git a/src/compiler/spirv/vtn_cmat.c b/src/compiler/spirv/vtn_cmat.c index c5909e842d5..746c58e2eea 100644 --- a/src/compiler/spirv/vtn_cmat.c +++ b/src/compiler/spirv/vtn_cmat.c @@ -223,12 +223,10 @@ vtn_handle_cooperative_alu(struct vtn_builder *b, struct vtn_value *dest_val, struct vtn_type *dst_type = vtn_get_type(b, w[1]); nir_deref_instr *src = vtn_get_cmat_deref(b, w[3]); - unsigned src_bit_size = glsl_get_bit_size(glsl_get_cmat_element(src->type)); - unsigned dst_bit_size = glsl_get_bit_size(glsl_get_cmat_element(dst_type->type)); - bool ignored = false; nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &ignored, &ignored, - src_bit_size, dst_bit_size); + glsl_get_cmat_element(src->type), + glsl_get_cmat_element(dst_type->type)); nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_unary"); nir_cmat_unary_op(&b->nb, &dst->def, &src->def, @@ -247,12 +245,15 @@ vtn_handle_cooperative_alu(struct vtn_builder *b, struct vtn_value *dest_val, case SpvOpSDiv: case SpvOpUDiv: { bool ignored = false; - nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &ignored, &ignored, 0, 0); struct vtn_type *dst_type = vtn_get_type(b, w[1]); nir_deref_instr *mat_a = vtn_get_cmat_deref(b, w[3]); nir_deref_instr *mat_b = vtn_get_cmat_deref(b, w[4]); + nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &ignored, &ignored, + glsl_get_cmat_element(mat_a->type), + glsl_get_cmat_element(dst_type->type)); + nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_binary"); nir_cmat_binary_op(&b->nb, &dst->def, &mat_a->def, &mat_b->def, .alu_op = op); diff --git a/src/compiler/spirv/vtn_private.h b/src/compiler/spirv/vtn_private.h index 5656c3c6dd3..3b83618778c 100644 --- a/src/compiler/spirv/vtn_private.h +++ b/src/compiler/spirv/vtn_private.h @@ -959,7 +959,8 @@ nir_alu_type vtn_convert_op_dst_type(SpvOp opcode); nir_op vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b, SpvOp opcode, bool *swap, bool *exact, - unsigned src_bit_size, unsigned dst_bit_size); + const glsl_type *src_type, + const glsl_type *dst_type); void vtn_handle_alu(struct vtn_builder *b, SpvOp opcode, const uint32_t *w, unsigned count);