From 2caba08c1af16b9aa972e9eb6c7595371650a351 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Sch=C3=BCrmann?= Date: Thu, 3 Sep 2020 11:59:00 +0100 Subject: [PATCH] aco: fix VOP3P assembly, VN and validation aco/opcodes: rename v_pk_fma_mix* -> v_fma_mix* and add modifier capabilities for VOP3P. Reviewed-by: Rhys Perry Part-of: --- src/amd/compiler/aco_assembler.cpp | 4 +- src/amd/compiler/aco_opcodes.py | 49 ++++++++++---------- src/amd/compiler/aco_opt_value_numbering.cpp | 12 +++++ src/amd/compiler/aco_validate.cpp | 7 +-- 4 files changed, 43 insertions(+), 29 deletions(-) diff --git a/src/amd/compiler/aco_assembler.cpp b/src/amd/compiler/aco_assembler.cpp index 0840d8af410..aed70afb28f 100644 --- a/src/amd/compiler/aco_assembler.cpp +++ b/src/amd/compiler/aco_assembler.cpp @@ -612,7 +612,7 @@ void emit_instruction(asm_context& ctx, std::vector& out, Instruction* encoding |= opcode << 16; encoding |= (vop3->clamp ? 1 : 0) << 15; encoding |= vop3->opsel_lo << 11; - encoding |= (vop3->opsel_hi & 0x4) ? 1 : 0 << 14; + encoding |= ((vop3->opsel_hi & 0x4) ? 1 : 0) << 14; for (unsigned i = 0; i < 3; i++) encoding |= vop3->neg_hi[i] << (8+i); encoding |= (0xFF & instr->definitions[0].physReg()); @@ -620,7 +620,7 @@ void emit_instruction(asm_context& ctx, std::vector& out, Instruction* encoding = 0; for (unsigned i = 0; i < instr->operands.size(); i++) encoding |= instr->operands[i].physReg() << (i * 9); - encoding |= vop3->opsel_hi & 0x3 << 27; + encoding |= (vop3->opsel_hi & 0x3) << 27; for (unsigned i = 0; i < 3; i++) encoding |= vop3->neg_lo[i] << (29+i); out.push_back(encoding); diff --git a/src/amd/compiler/aco_opcodes.py b/src/amd/compiler/aco_opcodes.py index 178e448662b..82b47e8e1eb 100644 --- a/src/amd/compiler/aco_opcodes.py +++ b/src/amd/compiler/aco_opcodes.py @@ -903,33 +903,34 @@ for i in range(8): # VOPP instructions: packed 16bit instructions - 1 or 2 inputs and 1 output VOPP = { - (0x00, "v_pk_mad_i16"), - (0x01, "v_pk_mul_lo_u16"), - (0x02, "v_pk_add_i16"), - (0x03, "v_pk_sub_i16"), - (0x04, "v_pk_lshlrev_b16"), - (0x05, "v_pk_lshrrev_b16"), - (0x06, "v_pk_ashrrev_i16"), - (0x07, "v_pk_max_i16"), - (0x08, "v_pk_min_i16"), - (0x09, "v_pk_mad_u16"), - (0x0a, "v_pk_add_u16"), - (0x0b, "v_pk_sub_u16"), - (0x0c, "v_pk_max_u16"), - (0x0d, "v_pk_min_u16"), - (0x0e, "v_pk_fma_f16"), - (0x0f, "v_pk_add_f16"), - (0x10, "v_pk_mul_f16"), - (0x11, "v_pk_min_f16"), - (0x12, "v_pk_max_f16"), - (0x20, "v_pk_fma_mix_f32"), # v_mad_mix_f32 in VEGA ISA, v_fma_mix_f32 in RDNA ISA - (0x21, "v_pk_fma_mixlo_f16"), # v_mad_mixlo_f16 in VEGA ISA, v_fma_mixlo_f16 in RDNA ISA - (0x22, "v_pk_fma_mixhi_f16"), # v_mad_mixhi_f16 in VEGA ISA, v_fma_mixhi_f16 in RDNA ISA + # opcode, name, input/output modifiers + (0x00, "v_pk_mad_i16", False), + (0x01, "v_pk_mul_lo_u16", False), + (0x02, "v_pk_add_i16", False), + (0x03, "v_pk_sub_i16", False), + (0x04, "v_pk_lshlrev_b16", False), + (0x05, "v_pk_lshrrev_b16", False), + (0x06, "v_pk_ashrrev_i16", False), + (0x07, "v_pk_max_i16", False), + (0x08, "v_pk_min_i16", False), + (0x09, "v_pk_mad_u16", False), + (0x0a, "v_pk_add_u16", False), + (0x0b, "v_pk_sub_u16", False), + (0x0c, "v_pk_max_u16", False), + (0x0d, "v_pk_min_u16", False), + (0x0e, "v_pk_fma_f16", True), + (0x0f, "v_pk_add_f16", True), + (0x10, "v_pk_mul_f16", True), + (0x11, "v_pk_min_f16", True), + (0x12, "v_pk_max_f16", True), + (0x20, "v_fma_mix_f32", True), # v_mad_mix_f32 in VEGA ISA, v_fma_mix_f32 in RDNA ISA + (0x21, "v_fma_mixlo_f16", True), # v_mad_mixlo_f16 in VEGA ISA, v_fma_mixlo_f16 in RDNA ISA + (0x22, "v_fma_mixhi_f16", True), # v_mad_mixhi_f16 in VEGA ISA, v_fma_mixhi_f16 in RDNA ISA } # note that these are only supported on gfx9+ so we'll need to distinguish between gfx8 and gfx9 here # (gfx6, gfx7, gfx8, gfx9, gfx10, name) = (-1, -1, -1, code, code, name) -for (code, name) in VOPP: - opcode(name, -1, code, code, Format.VOP3P) +for (code, name, modifiers) in VOPP: + opcode(name, -1, code, code, Format.VOP3P, modifiers, modifiers) # VINTERP instructions: diff --git a/src/amd/compiler/aco_opt_value_numbering.cpp b/src/amd/compiler/aco_opt_value_numbering.cpp index 8dc2812bc7e..2d1a69b1492 100644 --- a/src/amd/compiler/aco_opt_value_numbering.cpp +++ b/src/amd/compiler/aco_opt_value_numbering.cpp @@ -244,6 +244,18 @@ struct InstrPred { return false; return true; } + case Format::VOP3P: { + VOP3P_instruction* a3P = static_cast(a); + VOP3P_instruction* b3P = static_cast(b); + for (unsigned i = 0; i < 3; i++) { + if (a3P->neg_lo[i] != b3P->neg_lo[i] || + a3P->neg_hi[i] != b3P->neg_hi[i]) + return false; + } + return a3P->opsel_lo == b3P->opsel_lo && + a3P->opsel_hi == b3P->opsel_hi && + a3P->clamp == b3P->clamp; + } case Format::PSEUDO_REDUCTION: { Pseudo_reduction_instruction *aR = static_cast(a); Pseudo_reduction_instruction *bR = static_cast(b); diff --git a/src/amd/compiler/aco_validate.cpp b/src/amd/compiler/aco_validate.cpp index 0f19a1752b5..bf45cc7cb0e 100644 --- a/src/amd/compiler/aco_validate.cpp +++ b/src/amd/compiler/aco_validate.cpp @@ -240,12 +240,13 @@ bool validate_ir(Program* program) instr->format == Format::VOP1 || instr->format == Format::VOP2 || instr->format == Format::VOPC || - (instr->isVOP3() && program->chip_class >= GFX10), + (instr->isVOP3() && program->chip_class >= GFX10) || + (instr->format == Format::VOP3P && program->chip_class >= GFX10), "Literal applied on wrong instruction format", instr.get()); check(literal.isUndefined() || (literal.size() == op.size() && literal.constantValue() == op.constantValue()), "Only 1 Literal allowed", instr.get()); literal = op; - check(!instr->isVALU() || instr->isVOP3() || i == 0 || i == 2, "Wrong source position for Literal argument", instr.get()); + check(instr->isSALU() || instr->isVOP3() || instr->format == Format::VOP3P || i == 0 || i == 2, "Wrong source position for Literal argument", instr.get()); } /* check num sgprs for VALU */ @@ -257,7 +258,7 @@ bool validate_ir(Program* program) if (program->chip_class >= GFX10 && !is_shift64) const_bus_limit = 2; - uint32_t scalar_mask = instr->isVOP3() ? 0x7 : 0x5; + uint32_t scalar_mask = instr->isVOP3() || instr->format == Format::VOP3P ? 0x7 : 0x5; if (instr->isSDWA()) scalar_mask = program->chip_class >= GFX9 ? 0x7 : 0x4;