aco: add fma_mix opcodes with rtz fp16 rounding

Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/38815>
This commit is contained in:
Georg Lehmann 2025-10-19 16:40:56 +02:00 committed by Marge Bot
parent af68c08e88
commit d6356191b9
6 changed files with 33 additions and 2 deletions

View file

@ -289,6 +289,16 @@ emit_set_mode_block(fp_mode_ctx* ctx, Block* block)
} else if (instr->opcode == aco_opcode::p_v_cvt_pk_fp8_f32_ovfl) {
set_mode |= fp_state.require(mode_fp16_ovfl, 1);
instr->opcode = aco_opcode::v_cvt_pk_fp8_f32;
} else if (instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz ||
instr->opcode == aco_opcode::p_v_fma_mixhi_f16_rtz) {
set_mode |= fp_state.require(mode_round16_64, fp_round_tz);
set_mode |= fp_state.require(mode_round32, default_state.fields[mode_round32]);
set_mode |= fp_state.require(mode_denorm16_64, default_state.fields[mode_denorm16_64]);
set_mode |= fp_state.require(mode_denorm32, default_state.fields[mode_denorm32]);
if (instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz)
instr->opcode = aco_opcode::v_fma_mixlo_f16;
else
instr->opcode = aco_opcode::v_fma_mixhi_f16;
} else {
mode_mask default_needs = instr_default_needs(ctx, block, instr);
u_foreach_bit (i, default_needs)

View file

@ -438,6 +438,8 @@ can_use_DPP(amd_gfx_level gfx_level, const aco_ptr<Instruction>& instr, bool dpp
return instr->opcode == aco_opcode::v_fma_mix_f32 ||
instr->opcode == aco_opcode::v_fma_mixlo_f16 ||
instr->opcode == aco_opcode::v_fma_mixhi_f16 ||
instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz ||
instr->opcode == aco_opcode::p_v_fma_mixhi_f16_rtz ||
instr->opcode == aco_opcode::v_dot2_f32_f16 ||
instr->opcode == aco_opcode::v_dot2_f32_bf16;
}
@ -644,6 +646,8 @@ instr_is_16bit(amd_gfx_level gfx_level, aco_opcode op)
case aco_opcode::v_interp_p2_hi_f16:
case aco_opcode::v_fma_mixlo_f16:
case aco_opcode::v_fma_mixhi_f16:
case aco_opcode::p_v_fma_mixlo_f16_rtz:
case aco_opcode::p_v_fma_mixhi_f16_rtz:
/* VOP2 */
case aco_opcode::v_mac_f16:
case aco_opcode::v_madak_f16:
@ -861,7 +865,9 @@ get_operand_type(aco_ptr<Instruction>& alu, unsigned index)
aco_type type = instr_info.alu_opcode_infos[(int)alu->opcode].op_types[index];
if (alu->opcode == aco_opcode::v_fma_mix_f32 || alu->opcode == aco_opcode::v_fma_mixlo_f16 ||
alu->opcode == aco_opcode::v_fma_mixhi_f16)
alu->opcode == aco_opcode::v_fma_mixhi_f16 ||
alu->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz ||
alu->opcode == aco_opcode::p_v_fma_mixhi_f16_rtz)
type.bit_size = alu->valu().opsel_hi[index] ? 16 : 32;
return type;
@ -1154,6 +1160,8 @@ get_swapped_opcode(aco_opcode opcode, unsigned idx0, unsigned idx1)
case aco_opcode::v_fma_mix_f32:
case aco_opcode::v_fma_mixlo_f16:
case aco_opcode::v_fma_mixhi_f16:
case aco_opcode::p_v_fma_mixlo_f16_rtz:
case aco_opcode::p_v_fma_mixhi_f16_rtz:
case aco_opcode::v_pk_fmac_f16: {
if (idx1 == 2)
return aco_opcode::num_opcodes;

View file

@ -1228,6 +1228,8 @@ VOPP = {
("v_fma_mix_f32", dst(F32), src(F32, F32, F32), op(gfx9=0x20)), # v_mad_mix_f32 in VEGA ISA, v_fma_mix_f32 in RDNA ISA
("v_fma_mixlo_f16", dst(F16), src(F32, F32, F32), op(gfx9=0x21)), # v_mad_mixlo_f16 in VEGA ISA, v_fma_mixlo_f16 in RDNA ISA
("v_fma_mixhi_f16", dst(F16), src(F32, F32, F32), op(gfx9=0x22)), # v_mad_mixhi_f16 in VEGA ISA, v_fma_mixhi_f16 in RDNA ISA
("p_v_fma_mixlo_f16_rtz", dst(F16), src(F32, F32, F32), op(-1)), # v_fma_mixlo_f16 with fp16 rtz rounding
("p_v_fma_mixhi_f16_rtz", dst(F16), src(F32, F32, F32), op(-1)), # v_fma_mixhi_f16 with fp16 rtz rounding
("v_dot2_i32_i16", dst(U32), src(PkU16, PkU16, U32), op(gfx9=0x26, gfx10=0x14, gfx11=-1)),
("v_dot2_u32_u16", dst(U32), src(PkU16, PkU16, U32), op(gfx9=0x27, gfx10=0x15, gfx11=-1)),
("v_dot4_i32_iu8", dst(U32), src(PkU16, PkU16, U32), op(gfx11=0x16)),

View file

@ -1044,7 +1044,9 @@ aco_print_instr(enum amd_gfx_level gfx_level, const Instruction* instr, FILE* ou
if (instr->opcode == aco_opcode::v_fma_mix_f32 ||
instr->opcode == aco_opcode::v_fma_mixlo_f16 ||
instr->opcode == aco_opcode::v_fma_mixhi_f16) {
instr->opcode == aco_opcode::v_fma_mixhi_f16 ||
instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz ||
instr->opcode == aco_opcode::p_v_fma_mixhi_f16_rtz) {
const VALU_instruction& vop3p = instr->valu();
abs = vop3p.abs;
neg = vop3p.neg;

View file

@ -771,6 +771,7 @@ DefInfo::get_subdword_definition_info(Program* program, const aco_ptr<Instructio
rc = instr_is_16bit(gfx_level, instr->opcode) ? v2b : v1;
stride = 4;
if (instr->opcode == aco_opcode::v_fma_mixlo_f16 ||
instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz ||
can_use_opsel(gfx_level, instr->opcode, -1)) {
data_stride = 2;
stride = rc == v2b ? 2 : stride;
@ -861,6 +862,9 @@ add_subdword_definition(Program* program, aco_ptr<Instruction>& instr, PhysReg r
if (instr->opcode == aco_opcode::v_fma_mixlo_f16) {
instr->opcode = aco_opcode::v_fma_mixhi_f16;
return;
} else if (instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz) {
instr->opcode = aco_opcode::p_v_fma_mixhi_f16_rtz;
return;
}
if (convert_bitwise_to_16bit(instr.get())) {

View file

@ -404,6 +404,8 @@ validate_ir(Program* program)
check(!valu.opsel[3], "Unexpected opsel for sub-dword definition", instr.get());
} else if (instr->opcode == aco_opcode::v_fma_mixlo_f16 ||
instr->opcode == aco_opcode::v_fma_mixhi_f16 ||
instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz ||
instr->opcode == aco_opcode::p_v_fma_mixhi_f16_rtz ||
instr->opcode == aco_opcode::v_fma_mix_f32) {
check(instr->definitions[0].regClass() ==
(instr->opcode == aco_opcode::v_fma_mix_f32 ? v1 : v2b),
@ -1348,6 +1350,8 @@ validate_subdword_operand(amd_gfx_level gfx_level, const aco_ptr<Instruction>& i
if (instr->isVOP3P()) {
bool fma_mix = instr->opcode == aco_opcode::v_fma_mixlo_f16 ||
instr->opcode == aco_opcode::v_fma_mixhi_f16 ||
instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz ||
instr->opcode == aco_opcode::p_v_fma_mixhi_f16_rtz ||
instr->opcode == aco_opcode::v_fma_mix_f32;
return instr->valu().opsel_lo[index] == (byte >> 1) &&
instr->valu().opsel_hi[index] == (fma_mix || (byte >> 1));
@ -1411,6 +1415,7 @@ validate_subdword_definition(amd_gfx_level gfx_level, const aco_ptr<Instruction>
switch (instr->opcode) {
case aco_opcode::v_interp_p2_hi_f16:
case aco_opcode::v_fma_mixhi_f16:
case aco_opcode::p_v_fma_mixhi_f16_rtz:
case aco_opcode::buffer_load_ubyte_d16_hi:
case aco_opcode::buffer_load_sbyte_d16_hi:
case aco_opcode::buffer_load_short_d16_hi: