aco: implement udot_4x8/sdot_4x8/udot_2x16/sdot_2x16 opcodes

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/12617>
This commit is contained in:
Rhys Perry 2021-08-02 19:42:44 +01:00 committed by Marge Bot
parent e0d232c2fc
commit 2a7fa132be
4 changed files with 69 additions and 3 deletions

View file

@ -960,6 +960,24 @@ emit_vop3p_instruction(isel_context* ctx, nir_alu_instr* instr, aco_opcode op, T
return res;
}
void
emit_idot_instruction(isel_context* ctx, nir_alu_instr* instr, aco_opcode op, Temp dst, bool clamp)
{
Temp src[3] = {Temp(0, v1), Temp(0, v1), Temp(0, v1)};
bool has_sgpr = false;
for (unsigned i = 0; i < 3; i++) {
src[i] = get_alu_src(ctx, instr->src[i]);
if (has_sgpr)
src[i] = as_vgpr(ctx, src[i]);
else
has_sgpr = src[i].type() == RegType::sgpr;
}
Builder bld(ctx->program, ctx->block);
bld.is_precise = instr->exact;
bld.vop3p(op, Definition(dst), src[0], src[1], src[2], 0x0, 0x7).instr->vop3p().clamp = clamp;
}
void
emit_vop1_instruction(isel_context* ctx, nir_alu_instr* instr, aco_opcode op, Temp dst)
{
@ -2112,6 +2130,38 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
}
break;
}
case nir_op_sdot_4x8_iadd: {
emit_idot_instruction(ctx, instr, aco_opcode::v_dot4_i32_i8, dst, false);
break;
}
case nir_op_sdot_4x8_iadd_sat: {
emit_idot_instruction(ctx, instr, aco_opcode::v_dot4_i32_i8, dst, true);
break;
}
case nir_op_udot_4x8_uadd: {
emit_idot_instruction(ctx, instr, aco_opcode::v_dot4_u32_u8, dst, false);
break;
}
case nir_op_udot_4x8_uadd_sat: {
emit_idot_instruction(ctx, instr, aco_opcode::v_dot4_u32_u8, dst, true);
break;
}
case nir_op_sdot_2x16_iadd: {
emit_idot_instruction(ctx, instr, aco_opcode::v_dot2_i32_i16, dst, false);
break;
}
case nir_op_sdot_2x16_iadd_sat: {
emit_idot_instruction(ctx, instr, aco_opcode::v_dot2_i32_i16, dst, true);
break;
}
case nir_op_udot_2x16_uadd: {
emit_idot_instruction(ctx, instr, aco_opcode::v_dot2_u32_u16, dst, false);
break;
}
case nir_op_udot_2x16_uadd_sat: {
emit_idot_instruction(ctx, instr, aco_opcode::v_dot2_u32_u16, dst, true);
break;
}
case nir_op_cube_face_coord_amd: {
Temp in = get_alu_src(ctx, instr->src[0], 3);
Temp src[3] = {emit_extract_vector(ctx, in, 0, v1), emit_extract_vector(ctx, in, 1, v1),

View file

@ -594,7 +594,15 @@ init_context(isel_context* ctx, nir_shader* shader)
case nir_op_cube_face_index_amd:
case nir_op_cube_face_coord_amd:
case nir_op_sad_u8x4:
case nir_op_iadd_sat: type = RegType::vgpr; break;
case nir_op_iadd_sat:
case nir_op_udot_4x8_uadd:
case nir_op_sdot_4x8_iadd:
case nir_op_udot_4x8_uadd_sat:
case nir_op_sdot_4x8_iadd_sat:
case nir_op_udot_2x16_uadd:
case nir_op_sdot_2x16_iadd:
case nir_op_udot_2x16_uadd_sat:
case nir_op_sdot_2x16_iadd_sat: type = RegType::vgpr; break;
case nir_op_f2i16:
case nir_op_f2u16:
case nir_op_f2i32:

View file

@ -680,6 +680,7 @@ VOP2 = {
(0x0a, 0x0a, 0x07, 0x07, 0x0a, "v_mul_hi_i32_i24", False),
(0x0b, 0x0b, 0x08, 0x08, 0x0b, "v_mul_u32_u24", False),
(0x0c, 0x0c, 0x09, 0x09, 0x0c, "v_mul_hi_u32_u24", False),
( -1, -1, -1, 0x39, 0x0d, "v_dot4c_i32_i8", False),
(0x0d, 0x0d, -1, -1, -1, "v_min_legacy_f32", True),
(0x0e, 0x0e, -1, -1, -1, "v_max_legacy_f32", True),
(0x0f, 0x0f, 0x0a, 0x0a, 0x0f, "v_min_f32", True),
@ -963,6 +964,10 @@ VOPP = {
# (gfx6, gfx7, gfx8, gfx9, gfx10, name) = (-1, -1, -1, code, code, name)
for (code, name, modifiers) in VOPP:
opcode(name, -1, code, code, Format.VOP3P, InstrClass.Valu32, modifiers, modifiers)
opcode("v_dot2_i32_i16", -1, 0x26, 0x14, Format.VOP3P, InstrClass.Valu32)
opcode("v_dot2_u32_u16", -1, 0x27, 0x15, Format.VOP3P, InstrClass.Valu32)
opcode("v_dot4_i32_i8", -1, 0x28, 0x16, Format.VOP3P, InstrClass.Valu32)
opcode("v_dot4_u32_u8", -1, 0x29, 0x17, Format.VOP3P, InstrClass.Valu32)
# VINTERP instructions:

View file

@ -2477,7 +2477,8 @@ register_allocation(Program* program, std::vector<IDSet>& live_out_per_block, ra
instr->opcode == aco_opcode::v_mad_f16 ||
instr->opcode == aco_opcode::v_mad_legacy_f16 ||
(instr->opcode == aco_opcode::v_fma_f16 && program->chip_class >= GFX10) ||
(instr->opcode == aco_opcode::v_pk_fma_f16 && program->chip_class >= GFX10)) &&
(instr->opcode == aco_opcode::v_pk_fma_f16 && program->chip_class >= GFX10) ||
(instr->opcode == aco_opcode::v_dot4_i32_i8 && program->family != CHIP_VEGA20)) &&
instr->operands[2].isTemp() && instr->operands[2].isKillBeforeDef() &&
instr->operands[2].getTemp().type() == RegType::vgpr && instr->operands[1].isTemp() &&
instr->operands[1].getTemp().type() == RegType::vgpr && !instr->usesModifiers() &&
@ -2496,6 +2497,7 @@ register_allocation(Program* program, std::vector<IDSet>& live_out_per_block, ra
case aco_opcode::v_mad_legacy_f16: instr->opcode = aco_opcode::v_mac_f16; break;
case aco_opcode::v_fma_f16: instr->opcode = aco_opcode::v_fmac_f16; break;
case aco_opcode::v_pk_fma_f16: instr->opcode = aco_opcode::v_pk_fmac_f16; break;
case aco_opcode::v_dot4_i32_i8: instr->opcode = aco_opcode::v_dot4c_i32_i8; break;
default: break;
}
}
@ -2507,7 +2509,8 @@ register_allocation(Program* program, std::vector<IDSet>& live_out_per_block, ra
instr->opcode == aco_opcode::v_mac_f16 || instr->opcode == aco_opcode::v_fmac_f16 ||
instr->opcode == aco_opcode::v_pk_fmac_f16 ||
instr->opcode == aco_opcode::v_writelane_b32 ||
instr->opcode == aco_opcode::v_writelane_b32_e64) {
instr->opcode == aco_opcode::v_writelane_b32_e64 ||
instr->opcode == aco_opcode::v_dot4c_i32_i8) {
instr->definitions[0].setFixed(instr->operands[2].physReg());
} else if (instr->opcode == aco_opcode::s_addk_i32 ||
instr->opcode == aco_opcode::s_mulk_i32) {