radv,aco,nir: keep the A and B base type for cmat_muladd_amd

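With bfloat16, and the two fp8 formats in the future, the bit size alone can
no longer identify the A and B element types (float16 and bfloat16 are both
16 bits wide), so the intrinsic now carries their base types instead.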

Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34768>
This commit is contained in:
Georg Lehmann 2025-04-30 11:22:53 +02:00 committed by Marge Bot
parent c21e1776b3
commit e8f5c335ff
3 changed files with 25 additions and 9 deletions

View file

@ -8020,20 +8020,24 @@ visit_cmat_muladd(isel_context* ctx, nir_intrinsic_instr* instr)
bitarray8 neg_lo = nir_intrinsic_neg_lo_amd(instr);
bitarray8 neg_hi = nir_intrinsic_neg_hi_amd(instr);
switch (instr->src[0].ssa->bit_size) {
case 16:
enum glsl_base_type type_a = nir_intrinsic_src_base_type(instr);
enum glsl_base_type type_b = nir_intrinsic_src_base_type2(instr);
switch (type_a) {
case GLSL_TYPE_FLOAT16:
switch (instr->def.bit_size) {
case 32: opcode = aco_opcode::v_wmma_f32_16x16x16_f16; break;
case 16: opcode = aco_opcode::v_wmma_f16_16x16x16_f16; break;
}
break;
case 8: {
case GLSL_TYPE_UINT8:
case GLSL_TYPE_INT8: {
opcode = aco_opcode::v_wmma_i32_16x16x16_iu8;
unsigned signed_mask = nir_intrinsic_cmat_signed_mask(instr);
neg_lo[0] = signed_mask & NIR_CMAT_A_SIGNED;
neg_lo[1] = signed_mask & NIR_CMAT_B_SIGNED;
neg_lo[0] = type_a == GLSL_TYPE_INT8;
neg_lo[1] = type_b == GLSL_TYPE_INT8;
break;
}
default: unreachable("invalid cmat_muladd_amd type");
}
if (opcode == aco_opcode::num_opcodes)

View file

@ -417,11 +417,22 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
nir_def *B = radv_nir_load_cmat(&b, &params, intr->src[2].ssa);
nir_def *C = radv_nir_load_cmat(&b, &params, intr->src[3].ssa);
nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
nir_deref_instr *a_deref = nir_src_as_deref(intr->src[1]);
nir_deref_instr *b_deref = nir_src_as_deref(intr->src[2]);
struct glsl_cmat_description a_desc = *glsl_get_cmat_description(a_deref->type);
struct glsl_cmat_description b_desc = *glsl_get_cmat_description(b_deref->type);
const nir_cmat_signed cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intr);
enum glsl_base_type a_element_type =
glsl_apply_signedness_to_base_type(a_desc.element_type, cmat_signed_mask & NIR_CMAT_A_SIGNED);
enum glsl_base_type b_element_type =
glsl_apply_signedness_to_base_type(b_desc.element_type, cmat_signed_mask & NIR_CMAT_B_SIGNED);
nir_def *ret = nir_cmat_muladd_amd(&b, A, B, C, .saturate = nir_intrinsic_saturate(intr),
.cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intr));
.src_base_type = a_element_type, .src_base_type2 = b_element_type);
nir_deref_instr *dst_deref = nir_src_as_deref(intr->src[0]);
nir_store_deref(&b, dst_deref, ret, nir_component_mask(ret->num_components));
nir_instr_remove(instr);
progress = true;

View file

@ -226,6 +226,7 @@ index("nir_alu_type", "dest_type")
# Source and destination data types for dpas_intel. The source types are also
# used by cmat_muladd_amd to describe its A and B matrices. Needed here to
# represent types that won't have a nir_alu_type.
index("enum glsl_base_type", "src_base_type")
index("enum glsl_base_type", "src_base_type2")
index("enum glsl_base_type", "dest_base_type")
# The swizzle mask for quad_swizzle_amd & masked_swizzle_amd
@ -1982,7 +1983,7 @@ intrinsic("strict_wqm_coord_amd", src_comp=[0], dest_comp=0, bit_sizes=[32], ind
flags=[CAN_ELIMINATE])
intrinsic("cmat_muladd_amd", src_comp=[-1, -1, 0], dest_comp=0, bit_sizes=src2,
indices=[SATURATE, NEG_LO_AMD, NEG_HI_AMD, CMAT_SIGNED_MASK], flags=[CAN_ELIMINATE])
indices=[SATURATE, NEG_LO_AMD, NEG_HI_AMD, SRC_BASE_TYPE, SRC_BASE_TYPE2], flags=[CAN_ELIMINATE])
# Get the debug log buffer descriptor.
intrinsic("load_debug_log_desc_amd", bit_sizes=[32], dest_comp=4, flags=[CAN_ELIMINATE, CAN_REORDER])