mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2025-12-24 11:00:11 +01:00
radv,aco,nir: keep the A and B base type for cmat_muladd_amd
With bfloat16, and the two fp8 formats in the future, using just the bit size to identify the types is no longer possible. Reviewed-by: Rhys Perry <pendingchaos02@gmail.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34768>
This commit is contained in:
parent
c21e1776b3
commit
e8f5c335ff
3 changed files with 25 additions and 9 deletions
|
|
@ -8020,20 +8020,24 @@ visit_cmat_muladd(isel_context* ctx, nir_intrinsic_instr* instr)
|
|||
bitarray8 neg_lo = nir_intrinsic_neg_lo_amd(instr);
|
||||
bitarray8 neg_hi = nir_intrinsic_neg_hi_amd(instr);
|
||||
|
||||
switch (instr->src[0].ssa->bit_size) {
|
||||
case 16:
|
||||
enum glsl_base_type type_a = nir_intrinsic_src_base_type(instr);
|
||||
enum glsl_base_type type_b = nir_intrinsic_src_base_type2(instr);
|
||||
|
||||
switch (type_a) {
|
||||
case GLSL_TYPE_FLOAT16:
|
||||
switch (instr->def.bit_size) {
|
||||
case 32: opcode = aco_opcode::v_wmma_f32_16x16x16_f16; break;
|
||||
case 16: opcode = aco_opcode::v_wmma_f16_16x16x16_f16; break;
|
||||
}
|
||||
break;
|
||||
case 8: {
|
||||
case GLSL_TYPE_UINT8:
|
||||
case GLSL_TYPE_INT8: {
|
||||
opcode = aco_opcode::v_wmma_i32_16x16x16_iu8;
|
||||
unsigned signed_mask = nir_intrinsic_cmat_signed_mask(instr);
|
||||
neg_lo[0] = signed_mask & NIR_CMAT_A_SIGNED;
|
||||
neg_lo[1] = signed_mask & NIR_CMAT_B_SIGNED;
|
||||
neg_lo[0] = type_a == GLSL_TYPE_INT8;
|
||||
neg_lo[1] = type_b == GLSL_TYPE_INT8;
|
||||
break;
|
||||
}
|
||||
default: unreachable("invalid cmat_muladd_amd type");
|
||||
}
|
||||
|
||||
if (opcode == aco_opcode::num_opcodes)
|
||||
|
|
|
|||
|
|
@ -417,11 +417,22 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
|
|||
nir_def *B = radv_nir_load_cmat(&b, ¶ms, intr->src[2].ssa);
|
||||
nir_def *C = radv_nir_load_cmat(&b, ¶ms, intr->src[3].ssa);
|
||||
|
||||
nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
|
||||
nir_deref_instr *a_deref = nir_src_as_deref(intr->src[1]);
|
||||
nir_deref_instr *b_deref = nir_src_as_deref(intr->src[2]);
|
||||
struct glsl_cmat_description a_desc = *glsl_get_cmat_description(a_deref->type);
|
||||
struct glsl_cmat_description b_desc = *glsl_get_cmat_description(b_deref->type);
|
||||
|
||||
const nir_cmat_signed cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intr);
|
||||
|
||||
enum glsl_base_type a_element_type =
|
||||
glsl_apply_signedness_to_base_type(a_desc.element_type, cmat_signed_mask & NIR_CMAT_A_SIGNED);
|
||||
enum glsl_base_type b_element_type =
|
||||
glsl_apply_signedness_to_base_type(b_desc.element_type, cmat_signed_mask & NIR_CMAT_B_SIGNED);
|
||||
|
||||
nir_def *ret = nir_cmat_muladd_amd(&b, A, B, C, .saturate = nir_intrinsic_saturate(intr),
|
||||
.cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intr));
|
||||
.src_base_type = a_element_type, .src_base_type2 = b_element_type);
|
||||
|
||||
nir_deref_instr *dst_deref = nir_src_as_deref(intr->src[0]);
|
||||
nir_store_deref(&b, dst_deref, ret, nir_component_mask(ret->num_components));
|
||||
nir_instr_remove(instr);
|
||||
progress = true;
|
||||
|
|
|
|||
|
|
@ -226,6 +226,7 @@ index("nir_alu_type", "dest_type")
|
|||
# Source and destination data types for dpas_intel. Needed here to
|
||||
# represent types that won't have a nir_alu_type.
|
||||
index("enum glsl_base_type", "src_base_type")
|
||||
index("enum glsl_base_type", "src_base_type2")
|
||||
index("enum glsl_base_type", "dest_base_type")
|
||||
|
||||
# The swizzle mask for quad_swizzle_amd & masked_swizzle_amd
|
||||
|
|
@ -1982,7 +1983,7 @@ intrinsic("strict_wqm_coord_amd", src_comp=[0], dest_comp=0, bit_sizes=[32], ind
|
|||
flags=[CAN_ELIMINATE])
|
||||
|
||||
intrinsic("cmat_muladd_amd", src_comp=[-1, -1, 0], dest_comp=0, bit_sizes=src2,
|
||||
indices=[SATURATE, NEG_LO_AMD, NEG_HI_AMD, CMAT_SIGNED_MASK], flags=[CAN_ELIMINATE])
|
||||
indices=[SATURATE, NEG_LO_AMD, NEG_HI_AMD, SRC_BASE_TYPE, SRC_BASE_TYPE2], flags=[CAN_ELIMINATE])
|
||||
|
||||
# Get the debug log buffer descriptor.
|
||||
intrinsic("load_debug_log_desc_amd", bit_sizes=[32], dest_comp=4, flags=[CAN_ELIMINATE, CAN_REORDER])
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue