ac/llvm: implement mixed float dot

Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/40003>
This commit is contained in:
Georg Lehmann 2026-02-18 13:21:14 +01:00 committed by Marge Bot
parent dd067088ef
commit 892274d20d

View file

@ -1153,6 +1153,21 @@ static bool visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
break;
}
case nir_op_bfdot2_fadd: {
/* Emit the "llvm.amdgcn.fdot2.f32.bf16" intrinsic: two 2-component vector
 * sources, an f32 addend (src[2]) and a clamp flag, producing an f32 result.
 */
const char *name = "llvm.amdgcn.fdot2.f32.bf16";
LLVMTypeRef vec2_type = ctx->ac.v2bf16;
#if LLVM_VERSION_MAJOR < 19 || (LLVM_VERSION_MAJOR == 19 && LLVM_VERSION_MINOR == 0)
/* Before LLVM 19.1, bf16 fdot used integer operands. */
vec2_type = ctx->ac.v2i16;
#endif
/* Reinterpret both vector sources as the operand type selected above. */
src[0] = LLVMBuildBitCast(ctx->ac.builder, src[0], vec2_type, "");
src[1] = LLVMBuildBitCast(ctx->ac.builder, src[1], vec2_type, "");
/* The addend is passed to the intrinsic as f32. */
src[2] = LLVMBuildBitCast(ctx->ac.builder, src[2], ctx->ac.f32, "");
src[3] = ctx->ac.i1false; /* clamp */
/* Four arguments are passed: src[0..2] plus the clamp flag in src[3]. */
result = ac_build_intrinsic(&ctx->ac, name, ctx->ac.f32, src, 4, 0);
break;
}
case nir_op_bfdot2_bfadd: {
const char *name = "llvm.amdgcn.fdot2.bf16.bf16";
LLVMTypeRef vec2_type = ctx->ac.v2bf16;
@ -1169,6 +1184,44 @@ static bool visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
break;
}
case nir_op_f16dot2_fadd: {
   /* Dot product of two v2f16 sources added to src[2]. The destination bit
    * size picks the variant: a 16-bit destination uses the three-argument
    * "llvm.amdgcn.fdot2.f16.f16" with an f16 addend/result; any other size
    * uses "llvm.amdgcn.fdot2" with an f32 addend/result and a trailing clamp
    * argument.
    */
   const bool f16_acc = instr->def.bit_size == 16;
   LLVMTypeRef acc_type = f16_acc ? ctx->ac.f16 : ctx->ac.f32;
   const char *intrin = f16_acc ? "llvm.amdgcn.fdot2.f16.f16" : "llvm.amdgcn.fdot2";
   unsigned num_args = 3;

   src[0] = LLVMBuildBitCast(ctx->ac.builder, src[0], ctx->ac.v2f16, "");
   src[1] = LLVMBuildBitCast(ctx->ac.builder, src[1], ctx->ac.v2f16, "");
   src[2] = LLVMBuildBitCast(ctx->ac.builder, src[2], acc_type, "");

   if (!f16_acc) {
      /* Only the f32 variant takes the extra clamp operand. */
      src[3] = ctx->ac.i1false; /* clamp */
      num_args = 4;
   }

   result = ac_build_intrinsic(&ctx->ac, intrin, acc_type, src, num_args, 0);
   break;
}
case nir_op_e4m3fn_dot4_fadd: {
   /* e4m3fn x e4m3fn maps to the "fp8.fp8" dot4 intrinsic. src[0] and src[1]
    * are forwarded as-is (presumably four packed 8-bit floats each - confirm
    * against the NIR opcode definition); only the addend is cast to f32.
    */
   LLVMTypeRef f32_type = ctx->ac.f32;

   src[2] = LLVMBuildBitCast(ctx->ac.builder, src[2], f32_type, "");
   result = ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.dot4.f32.fp8.fp8", f32_type, src, 3, 0);
   break;
}
case nir_op_e5m2_dot4_fadd: {
   /* e5m2 x e5m2 maps to the "bf8.bf8" dot4 intrinsic. src[0] and src[1] are
    * forwarded as-is (presumably four packed 8-bit floats each - confirm
    * against the NIR opcode definition); only the addend is cast to f32.
    */
   LLVMTypeRef f32_type = ctx->ac.f32;

   src[2] = LLVMBuildBitCast(ctx->ac.builder, src[2], f32_type, "");
   result = ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.dot4.f32.bf8.bf8", f32_type, src, 3, 0);
   break;
}
case nir_op_e4m3fn_e5m2_dot4_fadd: {
   /* Mixed e4m3fn x e5m2 maps to the "fp8.bf8" dot4 intrinsic. src[0] and
    * src[1] are forwarded as-is in their original order (presumably four
    * packed 8-bit floats each - confirm against the NIR opcode definition);
    * only the addend is cast to f32.
    */
   LLVMTypeRef f32_type = ctx->ac.f32;

   src[2] = LLVMBuildBitCast(ctx->ac.builder, src[2], f32_type, "");
   result = ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.dot4.f32.fp8.bf8", f32_type, src, 3, 0);
   break;
}
case nir_op_msad_4x8:
result = ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.msad.u8", ctx->ac.i32,
(LLVMValueRef[]){src[1], src[0], src[2]}, 3, 0);