diff --git a/src/amd/compiler/aco_insert_fp_mode.cpp b/src/amd/compiler/aco_insert_fp_mode.cpp index e74f2334b29..cab133abc26 100644 --- a/src/amd/compiler/aco_insert_fp_mode.cpp +++ b/src/amd/compiler/aco_insert_fp_mode.cpp @@ -289,6 +289,16 @@ emit_set_mode_block(fp_mode_ctx* ctx, Block* block) } else if (instr->opcode == aco_opcode::p_v_cvt_pk_fp8_f32_ovfl) { set_mode |= fp_state.require(mode_fp16_ovfl, 1); instr->opcode = aco_opcode::v_cvt_pk_fp8_f32; + } else if (instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz || + instr->opcode == aco_opcode::p_v_fma_mixhi_f16_rtz) { + set_mode |= fp_state.require(mode_round16_64, fp_round_tz); + set_mode |= fp_state.require(mode_round32, default_state.fields[mode_round32]); + set_mode |= fp_state.require(mode_denorm16_64, default_state.fields[mode_denorm16_64]); + set_mode |= fp_state.require(mode_denorm32, default_state.fields[mode_denorm32]); + if (instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz) + instr->opcode = aco_opcode::v_fma_mixlo_f16; + else + instr->opcode = aco_opcode::v_fma_mixhi_f16; } else { mode_mask default_needs = instr_default_needs(ctx, block, instr); u_foreach_bit (i, default_needs) diff --git a/src/amd/compiler/aco_ir.cpp b/src/amd/compiler/aco_ir.cpp index 404ab6bb326..d43a37a2631 100644 --- a/src/amd/compiler/aco_ir.cpp +++ b/src/amd/compiler/aco_ir.cpp @@ -438,6 +438,8 @@ can_use_DPP(amd_gfx_level gfx_level, const aco_ptr& instr, bool dpp return instr->opcode == aco_opcode::v_fma_mix_f32 || instr->opcode == aco_opcode::v_fma_mixlo_f16 || instr->opcode == aco_opcode::v_fma_mixhi_f16 || + instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz || + instr->opcode == aco_opcode::p_v_fma_mixhi_f16_rtz || instr->opcode == aco_opcode::v_dot2_f32_f16 || instr->opcode == aco_opcode::v_dot2_f32_bf16; } @@ -644,6 +646,8 @@ instr_is_16bit(amd_gfx_level gfx_level, aco_opcode op) case aco_opcode::v_interp_p2_hi_f16: case aco_opcode::v_fma_mixlo_f16: case aco_opcode::v_fma_mixhi_f16: + case aco_opcode::p_v_fma_mixlo_f16_rtz: + case aco_opcode::p_v_fma_mixhi_f16_rtz: /* VOP2 */ case aco_opcode::v_mac_f16: case aco_opcode::v_madak_f16: @@ -861,7 +865,9 @@ get_operand_type(aco_ptr& alu, unsigned index) aco_type type = instr_info.alu_opcode_infos[(int)alu->opcode].op_types[index]; if (alu->opcode == aco_opcode::v_fma_mix_f32 || alu->opcode == aco_opcode::v_fma_mixlo_f16 || - alu->opcode == aco_opcode::v_fma_mixhi_f16) + alu->opcode == aco_opcode::v_fma_mixhi_f16 || + alu->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz || + alu->opcode == aco_opcode::p_v_fma_mixhi_f16_rtz) type.bit_size = alu->valu().opsel_hi[index] ? 16 : 32; return type; @@ -1154,6 +1160,8 @@ get_swapped_opcode(aco_opcode opcode, unsigned idx0, unsigned idx1) case aco_opcode::v_fma_mix_f32: case aco_opcode::v_fma_mixlo_f16: case aco_opcode::v_fma_mixhi_f16: + case aco_opcode::p_v_fma_mixlo_f16_rtz: + case aco_opcode::p_v_fma_mixhi_f16_rtz: case aco_opcode::v_pk_fmac_f16: { if (idx1 == 2) return aco_opcode::num_opcodes; diff --git a/src/amd/compiler/aco_opcodes.py b/src/amd/compiler/aco_opcodes.py index 5ca1abe6a01..005d5dfd280 100644 --- a/src/amd/compiler/aco_opcodes.py +++ b/src/amd/compiler/aco_opcodes.py @@ -1228,6 +1228,8 @@ VOPP = { ("v_fma_mix_f32", dst(F32), src(F32, F32, F32), op(gfx9=0x20)), # v_mad_mix_f32 in VEGA ISA, v_fma_mix_f32 in RDNA ISA ("v_fma_mixlo_f16", dst(F16), src(F32, F32, F32), op(gfx9=0x21)), # v_mad_mixlo_f16 in VEGA ISA, v_fma_mixlo_f16 in RDNA ISA ("v_fma_mixhi_f16", dst(F16), src(F32, F32, F32), op(gfx9=0x22)), # v_mad_mixhi_f16 in VEGA ISA, v_fma_mixhi_f16 in RDNA ISA + ("p_v_fma_mixlo_f16_rtz", dst(F16), src(F32, F32, F32), op(-1)), # v_fma_mixlo_f16 with fp16 rtz rounding + ("p_v_fma_mixhi_f16_rtz", dst(F16), src(F32, F32, F32), op(-1)), # v_fma_mixhi_f16 with fp16 rtz rounding ("v_dot2_i32_i16", dst(U32), src(PkU16, PkU16, U32), op(gfx9=0x26, gfx10=0x14, gfx11=-1)), ("v_dot2_u32_u16", dst(U32), src(PkU16, PkU16, U32), op(gfx9=0x27, gfx10=0x15, gfx11=-1)), ("v_dot4_i32_iu8", dst(U32), src(PkU16, PkU16, U32), op(gfx11=0x16)), diff --git a/src/amd/compiler/aco_print_ir.cpp b/src/amd/compiler/aco_print_ir.cpp index 416e27899ba..be7febbb93f 100644 --- a/src/amd/compiler/aco_print_ir.cpp +++ b/src/amd/compiler/aco_print_ir.cpp @@ -1044,7 +1044,9 @@ aco_print_instr(enum amd_gfx_level gfx_level, const Instruction* instr, FILE* ou if (instr->opcode == aco_opcode::v_fma_mix_f32 || instr->opcode == aco_opcode::v_fma_mixlo_f16 || - instr->opcode == aco_opcode::v_fma_mixhi_f16) { + instr->opcode == aco_opcode::v_fma_mixhi_f16 || + instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz || + instr->opcode == aco_opcode::p_v_fma_mixhi_f16_rtz) { const VALU_instruction& vop3p = instr->valu(); abs = vop3p.abs; neg = vop3p.neg; diff --git a/src/amd/compiler/aco_register_allocation.cpp b/src/amd/compiler/aco_register_allocation.cpp index 6003293991d..f24a6902860 100644 --- a/src/amd/compiler/aco_register_allocation.cpp +++ b/src/amd/compiler/aco_register_allocation.cpp @@ -771,6 +771,7 @@ DefInfo::get_subdword_definition_info(Program* program, const aco_ptropcode) ? v2b : v1; stride = 4; if (instr->opcode == aco_opcode::v_fma_mixlo_f16 || + instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz || can_use_opsel(gfx_level, instr->opcode, -1)) { data_stride = 2; stride = rc == v2b ? 2 : stride; @@ -861,6 +862,9 @@ add_subdword_definition(Program* program, aco_ptr& instr, PhysReg r if (instr->opcode == aco_opcode::v_fma_mixlo_f16) { instr->opcode = aco_opcode::v_fma_mixhi_f16; return; + } else if (instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz) { + instr->opcode = aco_opcode::p_v_fma_mixhi_f16_rtz; + return; } if (convert_bitwise_to_16bit(instr.get())) { diff --git a/src/amd/compiler/aco_validate.cpp b/src/amd/compiler/aco_validate.cpp index 63fe8e7b156..d8183c54512 100644 --- a/src/amd/compiler/aco_validate.cpp +++ b/src/amd/compiler/aco_validate.cpp @@ -404,6 +404,8 @@ validate_ir(Program* program) check(!valu.opsel[3], "Unexpected opsel for sub-dword definition", instr.get()); } else if (instr->opcode == aco_opcode::v_fma_mixlo_f16 || instr->opcode == aco_opcode::v_fma_mixhi_f16 || + instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz || + instr->opcode == aco_opcode::p_v_fma_mixhi_f16_rtz || instr->opcode == aco_opcode::v_fma_mix_f32) { check(instr->definitions[0].regClass() == (instr->opcode == aco_opcode::v_fma_mix_f32 ? v1 : v2b), @@ -1348,6 +1350,8 @@ validate_subdword_operand(amd_gfx_level gfx_level, const aco_ptr& i if (instr->isVOP3P()) { bool fma_mix = instr->opcode == aco_opcode::v_fma_mixlo_f16 || instr->opcode == aco_opcode::v_fma_mixhi_f16 || + instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz || + instr->opcode == aco_opcode::p_v_fma_mixhi_f16_rtz || instr->opcode == aco_opcode::v_fma_mix_f32; return instr->valu().opsel_lo[index] == (byte >> 1) && instr->valu().opsel_hi[index] == (fma_mix || (byte >> 1)); @@ -1411,6 +1415,7 @@ validate_subdword_definition(amd_gfx_level gfx_level, const aco_ptr switch (instr->opcode) { case aco_opcode::v_interp_p2_hi_f16: case aco_opcode::v_fma_mixhi_f16: + case aco_opcode::p_v_fma_mixhi_f16_rtz: case aco_opcode::buffer_load_ubyte_d16_hi: case aco_opcode::buffer_load_sbyte_d16_hi: case aco_opcode::buffer_load_short_d16_hi: