diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp
index 28a1826b669..f5245645eb5 100644
--- a/src/amd/compiler/aco_instruction_selection.cpp
+++ b/src/amd/compiler/aco_instruction_selection.cpp
@@ -8020,20 +8020,24 @@ visit_cmat_muladd(isel_context* ctx, nir_intrinsic_instr* instr)
    bitarray8 neg_lo = nir_intrinsic_neg_lo_amd(instr);
    bitarray8 neg_hi = nir_intrinsic_neg_hi_amd(instr);
 
-   switch (instr->src[0].ssa->bit_size) {
-   case 16:
+   enum glsl_base_type type_a = nir_intrinsic_src_base_type(instr);
+   enum glsl_base_type type_b = nir_intrinsic_src_base_type2(instr);
+
+   switch (type_a) {
+   case GLSL_TYPE_FLOAT16:
       switch (instr->def.bit_size) {
       case 32: opcode = aco_opcode::v_wmma_f32_16x16x16_f16; break;
       case 16: opcode = aco_opcode::v_wmma_f16_16x16x16_f16; break;
       }
       break;
-   case 8: {
+   case GLSL_TYPE_UINT8:
+   case GLSL_TYPE_INT8: {
       opcode = aco_opcode::v_wmma_i32_16x16x16_iu8;
-      unsigned signed_mask = nir_intrinsic_cmat_signed_mask(instr);
-      neg_lo[0] = signed_mask & NIR_CMAT_A_SIGNED;
-      neg_lo[1] = signed_mask & NIR_CMAT_B_SIGNED;
+      neg_lo[0] = type_a == GLSL_TYPE_INT8;
+      neg_lo[1] = type_b == GLSL_TYPE_INT8;
       break;
    }
+   default: unreachable("invalid cmat_muladd_amd type");
    }
 
    if (opcode == aco_opcode::num_opcodes)
diff --git a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c
index 2e9ee078679..084952bca53 100644
--- a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c
+++ b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c
@@ -417,11 +417,22 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_level)
 
          nir_def *B = radv_nir_load_cmat(&b, &params, intr->src[2].ssa);
          nir_def *C = radv_nir_load_cmat(&b, &params, intr->src[3].ssa);
 
-         nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
+         nir_deref_instr *a_deref = nir_src_as_deref(intr->src[1]);
+         nir_deref_instr *b_deref = nir_src_as_deref(intr->src[2]);
+         struct glsl_cmat_description a_desc = *glsl_get_cmat_description(a_deref->type);
+         struct glsl_cmat_description b_desc = *glsl_get_cmat_description(b_deref->type);
+
+         const nir_cmat_signed cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intr);
+
+         enum glsl_base_type a_element_type =
+            glsl_apply_signedness_to_base_type(a_desc.element_type, cmat_signed_mask & NIR_CMAT_A_SIGNED);
+         enum glsl_base_type b_element_type =
+            glsl_apply_signedness_to_base_type(b_desc.element_type, cmat_signed_mask & NIR_CMAT_B_SIGNED);
 
          nir_def *ret = nir_cmat_muladd_amd(&b, A, B, C, .saturate = nir_intrinsic_saturate(intr),
-                                            .cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intr));
+                                            .src_base_type = a_element_type, .src_base_type2 = b_element_type);
+         nir_deref_instr *dst_deref = nir_src_as_deref(intr->src[0]);
          nir_store_deref(&b, dst_deref, ret, nir_component_mask(ret->num_components));
          nir_instr_remove(instr);
          progress = true;
diff --git a/src/compiler/nir/nir_intrinsics.py b/src/compiler/nir/nir_intrinsics.py
index 31d5a839dc3..fa1c71ae62b 100644
--- a/src/compiler/nir/nir_intrinsics.py
+++ b/src/compiler/nir/nir_intrinsics.py
@@ -226,6 +226,7 @@ index("nir_alu_type", "dest_type")
 # Source and destination data types for dpas_intel. Needed here to
 # represent types that won't have a nir_alu_type.
 index("enum glsl_base_type", "src_base_type")
+index("enum glsl_base_type", "src_base_type2")
 index("enum glsl_base_type", "dest_base_type")
 
 # The swizzle mask for quad_swizzle_amd & masked_swizzle_amd
@@ -1982,7 +1983,7 @@ intrinsic("strict_wqm_coord_amd", src_comp=[0], dest_comp=0, bit_sizes=[32], indices=[BASE],
           flags=[CAN_ELIMINATE])
 
 intrinsic("cmat_muladd_amd", src_comp=[-1, -1, 0], dest_comp=0, bit_sizes=src2,
-          indices=[SATURATE, NEG_LO_AMD, NEG_HI_AMD, CMAT_SIGNED_MASK], flags=[CAN_ELIMINATE])
+          indices=[SATURATE, NEG_LO_AMD, NEG_HI_AMD, SRC_BASE_TYPE, SRC_BASE_TYPE2], flags=[CAN_ELIMINATE])
 
 # Get the debug log buffer descriptor.
intrinsic("load_debug_log_desc_amd", bit_sizes=[32], dest_comp=4, flags=[CAN_ELIMINATE, CAN_REORDER])