From 2a7fa132be1ced3fa94503051807eb3bcb5ac0be Mon Sep 17 00:00:00 2001 From: Rhys Perry Date: Mon, 2 Aug 2021 19:42:44 +0100 Subject: [PATCH] aco: implement udot_4x8/sdot_4x8/udot_2x16/sdot_2x16 opcodes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Rhys Perry Reviewed-by: Timur Kristóf Part-of: --- .../compiler/aco_instruction_selection.cpp | 50 +++++++++++++++++++ .../aco_instruction_selection_setup.cpp | 10 +++- src/amd/compiler/aco_opcodes.py | 5 ++ src/amd/compiler/aco_register_allocation.cpp | 7 ++- 4 files changed, 69 insertions(+), 3 deletions(-) diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp index 5b17f449df2..04c6a7216d3 100644 --- a/src/amd/compiler/aco_instruction_selection.cpp +++ b/src/amd/compiler/aco_instruction_selection.cpp @@ -960,6 +960,24 @@ emit_vop3p_instruction(isel_context* ctx, nir_alu_instr* instr, aco_opcode op, T return res; } +void +emit_idot_instruction(isel_context* ctx, nir_alu_instr* instr, aco_opcode op, Temp dst, bool clamp) +{ + Temp src[3] = {Temp(0, v1), Temp(0, v1), Temp(0, v1)}; + bool has_sgpr = false; + for (unsigned i = 0; i < 3; i++) { + src[i] = get_alu_src(ctx, instr->src[i]); + if (has_sgpr) + src[i] = as_vgpr(ctx, src[i]); + else + has_sgpr = src[i].type() == RegType::sgpr; + } + + Builder bld(ctx->program, ctx->block); + bld.is_precise = instr->exact; + bld.vop3p(op, Definition(dst), src[0], src[1], src[2], 0x0, 0x7).instr->vop3p().clamp = clamp; +} + void emit_vop1_instruction(isel_context* ctx, nir_alu_instr* instr, aco_opcode op, Temp dst) { @@ -2112,6 +2130,38 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr) } break; } + case nir_op_sdot_4x8_iadd: { + emit_idot_instruction(ctx, instr, aco_opcode::v_dot4_i32_i8, dst, false); + break; + } + case nir_op_sdot_4x8_iadd_sat: { + emit_idot_instruction(ctx, instr, aco_opcode::v_dot4_i32_i8, dst, true); + break; + } + case nir_op_udot_4x8_uadd: { + emit_idot_instruction(ctx, instr, aco_opcode::v_dot4_u32_u8, dst, false); + break; + } + case nir_op_udot_4x8_uadd_sat: { + emit_idot_instruction(ctx, instr, aco_opcode::v_dot4_u32_u8, dst, true); + break; + } + case nir_op_sdot_2x16_iadd: { + emit_idot_instruction(ctx, instr, aco_opcode::v_dot2_i32_i16, dst, false); + break; + } + case nir_op_sdot_2x16_iadd_sat: { + emit_idot_instruction(ctx, instr, aco_opcode::v_dot2_i32_i16, dst, true); + break; + } + case nir_op_udot_2x16_uadd: { + emit_idot_instruction(ctx, instr, aco_opcode::v_dot2_u32_u16, dst, false); + break; + } + case nir_op_udot_2x16_uadd_sat: { + emit_idot_instruction(ctx, instr, aco_opcode::v_dot2_u32_u16, dst, true); + break; + } case nir_op_cube_face_coord_amd: { Temp in = get_alu_src(ctx, instr->src[0], 3); Temp src[3] = {emit_extract_vector(ctx, in, 0, v1), emit_extract_vector(ctx, in, 1, v1), diff --git a/src/amd/compiler/aco_instruction_selection_setup.cpp b/src/amd/compiler/aco_instruction_selection_setup.cpp index 5371628d7ac..e9b333110d7 100644 --- a/src/amd/compiler/aco_instruction_selection_setup.cpp +++ b/src/amd/compiler/aco_instruction_selection_setup.cpp @@ -594,7 +594,15 @@ init_context(isel_context* ctx, nir_shader* shader) case nir_op_cube_face_index_amd: case nir_op_cube_face_coord_amd: case nir_op_sad_u8x4: - case nir_op_iadd_sat: type = RegType::vgpr; break; + case nir_op_iadd_sat: + case nir_op_udot_4x8_uadd: + case nir_op_sdot_4x8_iadd: + case nir_op_udot_4x8_uadd_sat: + case nir_op_sdot_4x8_iadd_sat: + case nir_op_udot_2x16_uadd: + case nir_op_sdot_2x16_iadd: + case nir_op_udot_2x16_uadd_sat: + case nir_op_sdot_2x16_iadd_sat: type = RegType::vgpr; break; case nir_op_f2i16: case nir_op_f2u16: case nir_op_f2i32: diff --git a/src/amd/compiler/aco_opcodes.py b/src/amd/compiler/aco_opcodes.py index 426c97c2069..bb027180456 100644 --- a/src/amd/compiler/aco_opcodes.py +++ b/src/amd/compiler/aco_opcodes.py @@ -680,6 +680,7 @@ VOP2 = { (0x0a, 0x0a, 0x07, 0x07, 0x0a, "v_mul_hi_i32_i24", False), (0x0b, 0x0b, 0x08, 0x08, 0x0b, "v_mul_u32_u24", False), (0x0c, 0x0c, 0x09, 0x09, 0x0c, "v_mul_hi_u32_u24", False), + ( -1, -1, -1, 0x39, 0x0d, "v_dot4c_i32_i8", False), (0x0d, 0x0d, -1, -1, -1, "v_min_legacy_f32", True), (0x0e, 0x0e, -1, -1, -1, "v_max_legacy_f32", True), (0x0f, 0x0f, 0x0a, 0x0a, 0x0f, "v_min_f32", True), @@ -963,6 +964,10 @@ VOPP = { # (gfx6, gfx7, gfx8, gfx9, gfx10, name) = (-1, -1, -1, code, code, name) for (code, name, modifiers) in VOPP: opcode(name, -1, code, code, Format.VOP3P, InstrClass.Valu32, modifiers, modifiers) +opcode("v_dot2_i32_i16", -1, 0x26, 0x14, Format.VOP3P, InstrClass.Valu32) +opcode("v_dot2_u32_u16", -1, 0x27, 0x15, Format.VOP3P, InstrClass.Valu32) +opcode("v_dot4_i32_i8", -1, 0x28, 0x16, Format.VOP3P, InstrClass.Valu32) +opcode("v_dot4_u32_u8", -1, 0x29, 0x17, Format.VOP3P, InstrClass.Valu32) # VINTERP instructions: diff --git a/src/amd/compiler/aco_register_allocation.cpp b/src/amd/compiler/aco_register_allocation.cpp index 8c0c5aabcd1..bcb70eeab48 100644 --- a/src/amd/compiler/aco_register_allocation.cpp +++ b/src/amd/compiler/aco_register_allocation.cpp @@ -2477,7 +2477,8 @@ register_allocation(Program* program, std::vector& live_out_per_block, ra instr->opcode == aco_opcode::v_mad_f16 || instr->opcode == aco_opcode::v_mad_legacy_f16 || (instr->opcode == aco_opcode::v_fma_f16 && program->chip_class >= GFX10) || - (instr->opcode == aco_opcode::v_pk_fma_f16 && program->chip_class >= GFX10)) && + (instr->opcode == aco_opcode::v_pk_fma_f16 && program->chip_class >= GFX10) || + (instr->opcode == aco_opcode::v_dot4_i32_i8 && program->family != CHIP_VEGA20)) && instr->operands[2].isTemp() && instr->operands[2].isKillBeforeDef() && instr->operands[2].getTemp().type() == RegType::vgpr && instr->operands[1].isTemp() && instr->operands[1].getTemp().type() == RegType::vgpr && !instr->usesModifiers() && @@ -2496,6 +2497,7 @@ register_allocation(Program* program, std::vector& live_out_per_block, ra case aco_opcode::v_mad_legacy_f16: instr->opcode = aco_opcode::v_mac_f16; break; case aco_opcode::v_fma_f16: instr->opcode = aco_opcode::v_fmac_f16; break; case aco_opcode::v_pk_fma_f16: instr->opcode = aco_opcode::v_pk_fmac_f16; break; + case aco_opcode::v_dot4_i32_i8: instr->opcode = aco_opcode::v_dot4c_i32_i8; break; default: break; } } @@ -2507,7 +2509,8 @@ register_allocation(Program* program, std::vector& live_out_per_block, ra instr->opcode == aco_opcode::v_mac_f16 || instr->opcode == aco_opcode::v_fmac_f16 || instr->opcode == aco_opcode::v_pk_fmac_f16 || instr->opcode == aco_opcode::v_writelane_b32 || - instr->opcode == aco_opcode::v_writelane_b32_e64) { + instr->opcode == aco_opcode::v_writelane_b32_e64 || + instr->opcode == aco_opcode::v_dot4c_i32_i8) { instr->definitions[0].setFixed(instr->operands[2].physReg()); } else if (instr->opcode == aco_opcode::s_addk_i32 || instr->opcode == aco_opcode::s_mulk_i32) {