mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-05-09 04:38:03 +02:00
aco: support bf16 wmma
Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34768>
This commit is contained in:
parent
e8f5c335ff
commit
5ca98bf99e
1 changed file with 15 additions and 4 deletions
|
|
@@ -8030,6 +8030,12 @@ visit_cmat_muladd(isel_context* ctx, nir_intrinsic_instr* instr)
|
||||||
case 16: opcode = aco_opcode::v_wmma_f16_16x16x16_f16; break;
|
case 16: opcode = aco_opcode::v_wmma_f16_16x16x16_f16; break;
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
case GLSL_TYPE_BFLOAT16:
|
||||||
|
switch (instr->def.bit_size) {
|
||||||
|
case 32: opcode = aco_opcode::v_wmma_f32_16x16x16_bf16; break;
|
||||||
|
case 16: opcode = aco_opcode::v_wmma_bf16_16x16x16_bf16; break;
|
||||||
|
}
|
||||||
|
break;
|
||||||
case GLSL_TYPE_UINT8:
|
case GLSL_TYPE_UINT8:
|
||||||
case GLSL_TYPE_INT8: {
|
case GLSL_TYPE_INT8: {
|
||||||
opcode = aco_opcode::v_wmma_i32_16x16x16_iu8;
|
opcode = aco_opcode::v_wmma_i32_16x16x16_iu8;
|
||||||
|
|
@@ -8053,13 +8059,18 @@ visit_cmat_muladd(isel_context* ctx, nir_intrinsic_instr* instr)
|
||||||
uint32_t constant;
|
uint32_t constant;
|
||||||
uint32_t acc_stride = ctx->program->gfx_level < GFX12 && instr->def.bit_size == 16 ? 2 : 1;
|
uint32_t acc_stride = ctx->program->gfx_level < GFX12 && instr->def.bit_size == 16 ? 2 : 1;
|
||||||
if (get_replicated_constant(instr->src[2].ssa, acc_stride, &constant)) {
|
if (get_replicated_constant(instr->src[2].ssa, acc_stride, &constant)) {
|
||||||
Operand constC =
|
unsigned constant_size = instr->def.bit_size;
|
||||||
Operand::get_const(ctx->program->gfx_level, constant, instr->def.bit_size / 8);
|
if (opcode == aco_opcode::v_wmma_bf16_16x16x16_bf16) {
|
||||||
|
/* Bfloat16 uses the high bits of 32bit inline constants. */
|
||||||
|
constant <<= 16;
|
||||||
|
constant_size = 32;
|
||||||
|
}
|
||||||
|
Operand constC = Operand::get_const(ctx->program->gfx_level, constant, constant_size / 8);
|
||||||
if (!constC.isLiteral()) {
|
if (!constC.isLiteral()) {
|
||||||
C = constC;
|
C = constC;
|
||||||
} else if (opcode != aco_opcode::v_wmma_i32_16x16x16_iu8) {
|
} else if (opcode != aco_opcode::v_wmma_i32_16x16x16_iu8) {
|
||||||
constant ^= 1 << (instr->def.bit_size - 1);
|
constant ^= 1 << (constant_size - 1);
|
||||||
constC = Operand::get_const(ctx->program->gfx_level, constant, instr->def.bit_size / 8);
|
constC = Operand::get_const(ctx->program->gfx_level, constant, constant_size / 8);
|
||||||
if (!constC.isLiteral()) {
|
if (!constC.isLiteral()) {
|
||||||
C = constC;
|
C = constC;
|
||||||
neg_lo[2] ^= !neg_hi[2];
|
neg_lo[2] ^= !neg_hi[2];
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue