From 892274d20d5f4d4d1659a2eb6039adecfedc80ee Mon Sep 17 00:00:00 2001 From: Georg Lehmann Date: Wed, 18 Feb 2026 13:21:14 +0100 Subject: [PATCH] ac/llvm: implement mixed float dot MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reviewed-by: Daniel Schürmann Part-of: --- src/amd/llvm/ac_nir_to_llvm.c | 53 +++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/src/amd/llvm/ac_nir_to_llvm.c b/src/amd/llvm/ac_nir_to_llvm.c index c9f1bcfd660..56ae0c47dff 100644 --- a/src/amd/llvm/ac_nir_to_llvm.c +++ b/src/amd/llvm/ac_nir_to_llvm.c @@ -1153,6 +1153,21 @@ static bool visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr) break; } + case nir_op_bfdot2_fadd: { + const char *name = "llvm.amdgcn.fdot2.f32.bf16"; + LLVMTypeRef vec2_type = ctx->ac.v2bf16; +#if LLVM_VERSION_MAJOR < 19 || (LLVM_VERSION_MAJOR == 19 && LLVM_VERSION_MINOR == 0) + /* Before LLVM 19.1, bf16 fdot used integer operands. */ + vec2_type = ctx->ac.v2i16; +#endif + src[0] = LLVMBuildBitCast(ctx->ac.builder, src[0], vec2_type, ""); + src[1] = LLVMBuildBitCast(ctx->ac.builder, src[1], vec2_type, ""); + src[2] = LLVMBuildBitCast(ctx->ac.builder, src[2], ctx->ac.f32, ""); + src[3] = ctx->ac.i1false; /* clamp */ + result = ac_build_intrinsic(&ctx->ac, name, ctx->ac.f32, src, 4, 0); + break; + } + case nir_op_bfdot2_bfadd: { const char *name = "llvm.amdgcn.fdot2.bf16.bf16"; LLVMTypeRef vec2_type = ctx->ac.v2bf16; @@ -1169,6 +1184,44 @@ static bool visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr) break; } + case nir_op_f16dot2_fadd: { + src[0] = LLVMBuildBitCast(ctx->ac.builder, src[0], ctx->ac.v2f16, ""); + src[1] = LLVMBuildBitCast(ctx->ac.builder, src[1], ctx->ac.v2f16, ""); + + if (instr->def.bit_size == 16) { + const char *name = "llvm.amdgcn.fdot2.f16.f16"; + src[2] = LLVMBuildBitCast(ctx->ac.builder, src[2], ctx->ac.f16, ""); + result = ac_build_intrinsic(&ctx->ac, name, ctx->ac.f16, src, 3, 0); + } else { + const char *name = "llvm.amdgcn.fdot2"; + src[2] = LLVMBuildBitCast(ctx->ac.builder, src[2], ctx->ac.f32, ""); + src[3] = ctx->ac.i1false; /* clamp */ + result = ac_build_intrinsic(&ctx->ac, name, ctx->ac.f32, src, 4, 0); + } + break; + } + + case nir_op_e4m3fn_dot4_fadd: { + const char *name = "llvm.amdgcn.dot4.f32.fp8.fp8"; + src[2] = LLVMBuildBitCast(ctx->ac.builder, src[2], ctx->ac.f32, ""); + result = ac_build_intrinsic(&ctx->ac, name, ctx->ac.f32, src, 3, 0); + break; + } + + case nir_op_e5m2_dot4_fadd: { + const char *name = "llvm.amdgcn.dot4.f32.bf8.bf8"; + src[2] = LLVMBuildBitCast(ctx->ac.builder, src[2], ctx->ac.f32, ""); + result = ac_build_intrinsic(&ctx->ac, name, ctx->ac.f32, src, 3, 0); + break; + } + + case nir_op_e4m3fn_e5m2_dot4_fadd: { + const char *name = "llvm.amdgcn.dot4.f32.fp8.bf8"; + src[2] = LLVMBuildBitCast(ctx->ac.builder, src[2], ctx->ac.f32, ""); + result = ac_build_intrinsic(&ctx->ac, name, ctx->ac.f32, src, 3, 0); + break; + } + case nir_op_msad_4x8: result = ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.msad.u8", ctx->ac.i32, (LLVMValueRef[]){src[1], src[0], src[2]}, 3, 0);