diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp index f5245645eb5..cf336cdb0ba 100644 --- a/src/amd/compiler/aco_instruction_selection.cpp +++ b/src/amd/compiler/aco_instruction_selection.cpp @@ -8030,6 +8030,12 @@ visit_cmat_muladd(isel_context* ctx, nir_intrinsic_instr* instr) case 16: opcode = aco_opcode::v_wmma_f16_16x16x16_f16; break; } break; + case GLSL_TYPE_BFLOAT16: + switch (instr->def.bit_size) { + case 32: opcode = aco_opcode::v_wmma_f32_16x16x16_bf16; break; + case 16: opcode = aco_opcode::v_wmma_bf16_16x16x16_bf16; break; + } + break; case GLSL_TYPE_UINT8: case GLSL_TYPE_INT8: { opcode = aco_opcode::v_wmma_i32_16x16x16_iu8; @@ -8053,13 +8059,18 @@ visit_cmat_muladd(isel_context* ctx, nir_intrinsic_instr* instr) uint32_t constant; uint32_t acc_stride = ctx->program->gfx_level < GFX12 && instr->def.bit_size == 16 ? 2 : 1; if (get_replicated_constant(instr->src[2].ssa, acc_stride, &constant)) { - Operand constC = - Operand::get_const(ctx->program->gfx_level, constant, instr->def.bit_size / 8); + unsigned constant_size = instr->def.bit_size; + if (opcode == aco_opcode::v_wmma_bf16_16x16x16_bf16) { + /* Bfloat16 uses the high bits of 32bit inline constants. */ + constant <<= 16; + constant_size = 32; + } + Operand constC = Operand::get_const(ctx->program->gfx_level, constant, constant_size / 8); if (!constC.isLiteral()) { C = constC; } else if (opcode != aco_opcode::v_wmma_i32_16x16x16_iu8) { - constant ^= 1 << (instr->def.bit_size - 1); - constC = Operand::get_const(ctx->program->gfx_level, constant, instr->def.bit_size / 8); + constant ^= 1 << (constant_size - 1); + constC = Operand::get_const(ctx->program->gfx_level, constant, constant_size / 8); if (!constC.isLiteral()) { C = constC; neg_lo[2] ^= !neg_hi[2];