mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2025-12-25 21:40:08 +01:00
aco: support bf16 wmma
Reviewed-by: Rhys Perry <pendingchaos02@gmail.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34768>
This commit is contained in:
parent
e8f5c335ff
commit
5ca98bf99e
1 changed files with 15 additions and 4 deletions
|
|
@ -8030,6 +8030,12 @@ visit_cmat_muladd(isel_context* ctx, nir_intrinsic_instr* instr)
|
|||
case 16: opcode = aco_opcode::v_wmma_f16_16x16x16_f16; break;
|
||||
}
|
||||
break;
|
||||
case GLSL_TYPE_BFLOAT16:
|
||||
switch (instr->def.bit_size) {
|
||||
case 32: opcode = aco_opcode::v_wmma_f32_16x16x16_bf16; break;
|
||||
case 16: opcode = aco_opcode::v_wmma_bf16_16x16x16_bf16; break;
|
||||
}
|
||||
break;
|
||||
case GLSL_TYPE_UINT8:
|
||||
case GLSL_TYPE_INT8: {
|
||||
opcode = aco_opcode::v_wmma_i32_16x16x16_iu8;
|
||||
|
|
@ -8053,13 +8059,18 @@ visit_cmat_muladd(isel_context* ctx, nir_intrinsic_instr* instr)
|
|||
uint32_t constant;
|
||||
uint32_t acc_stride = ctx->program->gfx_level < GFX12 && instr->def.bit_size == 16 ? 2 : 1;
|
||||
if (get_replicated_constant(instr->src[2].ssa, acc_stride, &constant)) {
|
||||
Operand constC =
|
||||
Operand::get_const(ctx->program->gfx_level, constant, instr->def.bit_size / 8);
|
||||
unsigned constant_size = instr->def.bit_size;
|
||||
if (opcode == aco_opcode::v_wmma_bf16_16x16x16_bf16) {
|
||||
/* Bfloat16 uses the high bits of 32bit inline constants. */
|
||||
constant <<= 16;
|
||||
constant_size = 32;
|
||||
}
|
||||
Operand constC = Operand::get_const(ctx->program->gfx_level, constant, constant_size / 8);
|
||||
if (!constC.isLiteral()) {
|
||||
C = constC;
|
||||
} else if (opcode != aco_opcode::v_wmma_i32_16x16x16_iu8) {
|
||||
constant ^= 1 << (instr->def.bit_size - 1);
|
||||
constC = Operand::get_const(ctx->program->gfx_level, constant, instr->def.bit_size / 8);
|
||||
constant ^= 1 << (constant_size - 1);
|
||||
constC = Operand::get_const(ctx->program->gfx_level, constant, constant_size / 8);
|
||||
if (!constC.isLiteral()) {
|
||||
C = constC;
|
||||
neg_lo[2] ^= !neg_hi[2];
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue