aco: support bf16 wmma

Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34768>
This commit is contained in:
Georg Lehmann 2025-04-30 11:29:27 +02:00 committed by Marge Bot
parent e8f5c335ff
commit 5ca98bf99e

View file

@ -8030,6 +8030,12 @@ visit_cmat_muladd(isel_context* ctx, nir_intrinsic_instr* instr)
case 16: opcode = aco_opcode::v_wmma_f16_16x16x16_f16; break;
}
break;
case GLSL_TYPE_BFLOAT16:
switch (instr->def.bit_size) {
case 32: opcode = aco_opcode::v_wmma_f32_16x16x16_bf16; break;
case 16: opcode = aco_opcode::v_wmma_bf16_16x16x16_bf16; break;
}
break;
case GLSL_TYPE_UINT8:
case GLSL_TYPE_INT8: {
opcode = aco_opcode::v_wmma_i32_16x16x16_iu8;
@ -8053,13 +8059,18 @@ visit_cmat_muladd(isel_context* ctx, nir_intrinsic_instr* instr)
uint32_t constant;
uint32_t acc_stride = ctx->program->gfx_level < GFX12 && instr->def.bit_size == 16 ? 2 : 1;
if (get_replicated_constant(instr->src[2].ssa, acc_stride, &constant)) {
Operand constC =
Operand::get_const(ctx->program->gfx_level, constant, instr->def.bit_size / 8);
unsigned constant_size = instr->def.bit_size;
if (opcode == aco_opcode::v_wmma_bf16_16x16x16_bf16) {
/* Bfloat16 uses the high bits of 32bit inline constants. */
constant <<= 16;
constant_size = 32;
}
Operand constC = Operand::get_const(ctx->program->gfx_level, constant, constant_size / 8);
if (!constC.isLiteral()) {
C = constC;
} else if (opcode != aco_opcode::v_wmma_i32_16x16x16_iu8) {
constant ^= 1 << (instr->def.bit_size - 1);
constC = Operand::get_const(ctx->program->gfx_level, constant, instr->def.bit_size / 8);
constant ^= 1 << (constant_size - 1);
constC = Operand::get_const(ctx->program->gfx_level, constant, constant_size / 8);
if (!constC.isLiteral()) {
C = constC;
neg_lo[2] ^= !neg_hi[2];