mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-03-17 12:30:33 +01:00
ac/llvm: implement mixed float dot
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/40003>
This commit is contained in:
parent
dd067088ef
commit
892274d20d
1 changed file with 53 additions and 0 deletions
|
|
@@ -1153,6 +1153,21 @@ static bool visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
|
|||
break;
|
||||
}
|
||||
|
||||
case nir_op_bfdot2_fadd: {
|
||||
const char *name = "llvm.amdgcn.fdot2.f32.bf16";
|
||||
LLVMTypeRef vec2_type = ctx->ac.v2bf16;
|
||||
#if LLVM_VERSION_MAJOR < 19 || (LLVM_VERSION_MAJOR == 19 && LLVM_VERSION_MINOR == 0)
|
||||
/* Before LLVM 19.1, bf16 fdot used integer operands. */
|
||||
vec2_type = ctx->ac.v2i16;
|
||||
#endif
|
||||
src[0] = LLVMBuildBitCast(ctx->ac.builder, src[0], vec2_type, "");
|
||||
src[1] = LLVMBuildBitCast(ctx->ac.builder, src[1], vec2_type, "");
|
||||
src[2] = LLVMBuildBitCast(ctx->ac.builder, src[2], ctx->ac.f32, "");
|
||||
src[3] = ctx->ac.i1false; /* clamp */
|
||||
result = ac_build_intrinsic(&ctx->ac, name, ctx->ac.f32, src, 4, 0);
|
||||
break;
|
||||
}
|
||||
|
||||
case nir_op_bfdot2_bfadd: {
|
||||
const char *name = "llvm.amdgcn.fdot2.bf16.bf16";
|
||||
LLVMTypeRef vec2_type = ctx->ac.v2bf16;
|
||||
|
|
@@ -1169,6 +1184,44 @@ static bool visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
|
|||
break;
|
||||
}
|
||||
|
||||
case nir_op_f16dot2_fadd: {
|
||||
src[0] = LLVMBuildBitCast(ctx->ac.builder, src[0], ctx->ac.v2f16, "");
|
||||
src[1] = LLVMBuildBitCast(ctx->ac.builder, src[1], ctx->ac.v2f16, "");
|
||||
|
||||
if (instr->def.bit_size == 16) {
|
||||
const char *name = "llvm.amdgcn.fdot2.f16.f16";
|
||||
src[2] = LLVMBuildBitCast(ctx->ac.builder, src[2], ctx->ac.f16, "");
|
||||
result = ac_build_intrinsic(&ctx->ac, name, ctx->ac.f16, src, 3, 0);
|
||||
} else {
|
||||
const char *name = "llvm.amdgcn.fdot2";
|
||||
src[2] = LLVMBuildBitCast(ctx->ac.builder, src[2], ctx->ac.f32, "");
|
||||
src[3] = ctx->ac.i1false; /* clamp */
|
||||
result = ac_build_intrinsic(&ctx->ac, name, ctx->ac.f32, src, 4, 0);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case nir_op_e4m3fn_dot4_fadd: {
|
||||
const char *name = "llvm.amdgcn.dot4.f32.fp8.fp8";
|
||||
src[2] = LLVMBuildBitCast(ctx->ac.builder, src[2], ctx->ac.f32, "");
|
||||
result = ac_build_intrinsic(&ctx->ac, name, ctx->ac.f32, src, 3, 0);
|
||||
break;
|
||||
}
|
||||
|
||||
case nir_op_e5m2_dot4_fadd: {
|
||||
const char *name = "llvm.amdgcn.dot4.f32.bf8.bf8";
|
||||
src[2] = LLVMBuildBitCast(ctx->ac.builder, src[2], ctx->ac.f32, "");
|
||||
result = ac_build_intrinsic(&ctx->ac, name, ctx->ac.f32, src, 3, 0);
|
||||
break;
|
||||
}
|
||||
|
||||
case nir_op_e4m3fn_e5m2_dot4_fadd: {
|
||||
const char *name = "llvm.amdgcn.dot4.f32.fp8.bf8";
|
||||
src[2] = LLVMBuildBitCast(ctx->ac.builder, src[2], ctx->ac.f32, "");
|
||||
result = ac_build_intrinsic(&ctx->ac, name, ctx->ac.f32, src, 3, 0);
|
||||
break;
|
||||
}
|
||||
|
||||
case nir_op_msad_4x8:
|
||||
result = ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.msad.u8", ctx->ac.i32,
|
||||
(LLVMValueRef[]){src[1], src[0], src[2]}, 3, 0);
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue