mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-06-15 15:38:22 +02:00
aco: add fma_mix opcodes with rtz fp16 rounding
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/38815>
This commit is contained in:
parent
af68c08e88
commit
d6356191b9
6 changed files with 33 additions and 2 deletions
|
|
@@ -289,6 +289,16 @@ emit_set_mode_block(fp_mode_ctx* ctx, Block* block)
|
|||
} else if (instr->opcode == aco_opcode::p_v_cvt_pk_fp8_f32_ovfl) {
|
||||
set_mode |= fp_state.require(mode_fp16_ovfl, 1);
|
||||
instr->opcode = aco_opcode::v_cvt_pk_fp8_f32;
|
||||
} else if (instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz ||
|
||||
instr->opcode == aco_opcode::p_v_fma_mixhi_f16_rtz) {
|
||||
set_mode |= fp_state.require(mode_round16_64, fp_round_tz);
|
||||
set_mode |= fp_state.require(mode_round32, default_state.fields[mode_round32]);
|
||||
set_mode |= fp_state.require(mode_denorm16_64, default_state.fields[mode_denorm16_64]);
|
||||
set_mode |= fp_state.require(mode_denorm32, default_state.fields[mode_denorm32]);
|
||||
if (instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz)
|
||||
instr->opcode = aco_opcode::v_fma_mixlo_f16;
|
||||
else
|
||||
instr->opcode = aco_opcode::v_fma_mixhi_f16;
|
||||
} else {
|
||||
mode_mask default_needs = instr_default_needs(ctx, block, instr);
|
||||
u_foreach_bit (i, default_needs)
|
||||
|
|
|
|||
|
|
@@ -438,6 +438,8 @@ can_use_DPP(amd_gfx_level gfx_level, const aco_ptr<Instruction>& instr, bool dpp
|
|||
return instr->opcode == aco_opcode::v_fma_mix_f32 ||
|
||||
instr->opcode == aco_opcode::v_fma_mixlo_f16 ||
|
||||
instr->opcode == aco_opcode::v_fma_mixhi_f16 ||
|
||||
instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz ||
|
||||
instr->opcode == aco_opcode::p_v_fma_mixhi_f16_rtz ||
|
||||
instr->opcode == aco_opcode::v_dot2_f32_f16 ||
|
||||
instr->opcode == aco_opcode::v_dot2_f32_bf16;
|
||||
}
|
||||
|
|
@@ -644,6 +646,8 @@ instr_is_16bit(amd_gfx_level gfx_level, aco_opcode op)
|
|||
case aco_opcode::v_interp_p2_hi_f16:
|
||||
case aco_opcode::v_fma_mixlo_f16:
|
||||
case aco_opcode::v_fma_mixhi_f16:
|
||||
case aco_opcode::p_v_fma_mixlo_f16_rtz:
|
||||
case aco_opcode::p_v_fma_mixhi_f16_rtz:
|
||||
/* VOP2 */
|
||||
case aco_opcode::v_mac_f16:
|
||||
case aco_opcode::v_madak_f16:
|
||||
|
|
@@ -861,7 +865,9 @@ get_operand_type(aco_ptr<Instruction>& alu, unsigned index)
|
|||
aco_type type = instr_info.alu_opcode_infos[(int)alu->opcode].op_types[index];
|
||||
|
||||
if (alu->opcode == aco_opcode::v_fma_mix_f32 || alu->opcode == aco_opcode::v_fma_mixlo_f16 ||
|
||||
alu->opcode == aco_opcode::v_fma_mixhi_f16)
|
||||
alu->opcode == aco_opcode::v_fma_mixhi_f16 ||
|
||||
alu->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz ||
|
||||
alu->opcode == aco_opcode::p_v_fma_mixhi_f16_rtz)
|
||||
type.bit_size = alu->valu().opsel_hi[index] ? 16 : 32;
|
||||
|
||||
return type;
|
||||
|
|
@@ -1154,6 +1160,8 @@ get_swapped_opcode(aco_opcode opcode, unsigned idx0, unsigned idx1)
|
|||
case aco_opcode::v_fma_mix_f32:
|
||||
case aco_opcode::v_fma_mixlo_f16:
|
||||
case aco_opcode::v_fma_mixhi_f16:
|
||||
case aco_opcode::p_v_fma_mixlo_f16_rtz:
|
||||
case aco_opcode::p_v_fma_mixhi_f16_rtz:
|
||||
case aco_opcode::v_pk_fmac_f16: {
|
||||
if (idx1 == 2)
|
||||
return aco_opcode::num_opcodes;
|
||||
|
|
|
|||
|
|
@@ -1228,6 +1228,8 @@ VOPP = {
|
|||
("v_fma_mix_f32", dst(F32), src(F32, F32, F32), op(gfx9=0x20)), # v_mad_mix_f32 in VEGA ISA, v_fma_mix_f32 in RDNA ISA
|
||||
("v_fma_mixlo_f16", dst(F16), src(F32, F32, F32), op(gfx9=0x21)), # v_mad_mixlo_f16 in VEGA ISA, v_fma_mixlo_f16 in RDNA ISA
|
||||
("v_fma_mixhi_f16", dst(F16), src(F32, F32, F32), op(gfx9=0x22)), # v_mad_mixhi_f16 in VEGA ISA, v_fma_mixhi_f16 in RDNA ISA
|
||||
("p_v_fma_mixlo_f16_rtz", dst(F16), src(F32, F32, F32), op(-1)), # v_fma_mixlo_f16 with fp16 rtz rounding
|
||||
("p_v_fma_mixhi_f16_rtz", dst(F16), src(F32, F32, F32), op(-1)), # v_fma_mixhi_f16 with fp16 rtz rounding
|
||||
("v_dot2_i32_i16", dst(U32), src(PkU16, PkU16, U32), op(gfx9=0x26, gfx10=0x14, gfx11=-1)),
|
||||
("v_dot2_u32_u16", dst(U32), src(PkU16, PkU16, U32), op(gfx9=0x27, gfx10=0x15, gfx11=-1)),
|
||||
("v_dot4_i32_iu8", dst(U32), src(PkU16, PkU16, U32), op(gfx11=0x16)),
|
||||
|
|
|
|||
|
|
@@ -1044,7 +1044,9 @@ aco_print_instr(enum amd_gfx_level gfx_level, const Instruction* instr, FILE* ou
|
|||
|
||||
if (instr->opcode == aco_opcode::v_fma_mix_f32 ||
|
||||
instr->opcode == aco_opcode::v_fma_mixlo_f16 ||
|
||||
instr->opcode == aco_opcode::v_fma_mixhi_f16) {
|
||||
instr->opcode == aco_opcode::v_fma_mixhi_f16 ||
|
||||
instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz ||
|
||||
instr->opcode == aco_opcode::p_v_fma_mixhi_f16_rtz) {
|
||||
const VALU_instruction& vop3p = instr->valu();
|
||||
abs = vop3p.abs;
|
||||
neg = vop3p.neg;
|
||||
|
|
|
|||
|
|
@@ -771,6 +771,7 @@ DefInfo::get_subdword_definition_info(Program* program, const aco_ptr<Instructio
|
|||
rc = instr_is_16bit(gfx_level, instr->opcode) ? v2b : v1;
|
||||
stride = 4;
|
||||
if (instr->opcode == aco_opcode::v_fma_mixlo_f16 ||
|
||||
instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz ||
|
||||
can_use_opsel(gfx_level, instr->opcode, -1)) {
|
||||
data_stride = 2;
|
||||
stride = rc == v2b ? 2 : stride;
|
||||
|
|
@@ -861,6 +862,9 @@ add_subdword_definition(Program* program, aco_ptr<Instruction>& instr, PhysReg r
|
|||
if (instr->opcode == aco_opcode::v_fma_mixlo_f16) {
|
||||
instr->opcode = aco_opcode::v_fma_mixhi_f16;
|
||||
return;
|
||||
} else if (instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz) {
|
||||
instr->opcode = aco_opcode::p_v_fma_mixhi_f16_rtz;
|
||||
return;
|
||||
}
|
||||
|
||||
if (convert_bitwise_to_16bit(instr.get())) {
|
||||
|
|
|
|||
|
|
@@ -404,6 +404,8 @@ validate_ir(Program* program)
|
|||
check(!valu.opsel[3], "Unexpected opsel for sub-dword definition", instr.get());
|
||||
} else if (instr->opcode == aco_opcode::v_fma_mixlo_f16 ||
|
||||
instr->opcode == aco_opcode::v_fma_mixhi_f16 ||
|
||||
instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz ||
|
||||
instr->opcode == aco_opcode::p_v_fma_mixhi_f16_rtz ||
|
||||
instr->opcode == aco_opcode::v_fma_mix_f32) {
|
||||
check(instr->definitions[0].regClass() ==
|
||||
(instr->opcode == aco_opcode::v_fma_mix_f32 ? v1 : v2b),
|
||||
|
|
@@ -1348,6 +1350,8 @@ validate_subdword_operand(amd_gfx_level gfx_level, const aco_ptr<Instruction>& i
|
|||
if (instr->isVOP3P()) {
|
||||
bool fma_mix = instr->opcode == aco_opcode::v_fma_mixlo_f16 ||
|
||||
instr->opcode == aco_opcode::v_fma_mixhi_f16 ||
|
||||
instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz ||
|
||||
instr->opcode == aco_opcode::p_v_fma_mixhi_f16_rtz ||
|
||||
instr->opcode == aco_opcode::v_fma_mix_f32;
|
||||
return instr->valu().opsel_lo[index] == (byte >> 1) &&
|
||||
instr->valu().opsel_hi[index] == (fma_mix || (byte >> 1));
|
||||
|
|
@@ -1411,6 +1415,7 @@ validate_subdword_definition(amd_gfx_level gfx_level, const aco_ptr<Instruction>
|
|||
switch (instr->opcode) {
|
||||
case aco_opcode::v_interp_p2_hi_f16:
|
||||
case aco_opcode::v_fma_mixhi_f16:
|
||||
case aco_opcode::p_v_fma_mixhi_f16_rtz:
|
||||
case aco_opcode::buffer_load_ubyte_d16_hi:
|
||||
case aco_opcode::buffer_load_sbyte_d16_hi:
|
||||
case aco_opcode::buffer_load_short_d16_hi:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue