ac/llvm: fix buffer_load_format with TFE by replacing inline asm with LLVM code

The inline-assembly implementation was broken on at least gfx12. Replacing it fixes
vkd3d-proton tests with RADV_DEBUG=llvm and removes the hard-to-maintain inline assembly.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/39474>
This commit is contained in:
Marek Olšák 2026-01-23 12:22:58 -05:00 committed by Marge Bot
parent bac80013e6
commit fc9f56556a

View file

@ -927,7 +927,7 @@ static LLVMValueRef ac_build_buffer_load_common(struct ac_llvm_context *ctx, LLV
LLVMValueRef vindex, LLVMValueRef voffset,
LLVMValueRef soffset, unsigned num_channels,
LLVMTypeRef channel_type, enum gl_access_qualifier access,
bool can_speculate, bool use_format)
bool can_speculate, bool use_format, bool tfe)
{
LLVMValueRef args[5];
int idx = 0;
@ -937,17 +937,27 @@ static LLVMValueRef ac_build_buffer_load_common(struct ac_llvm_context *ctx, LLV
args[idx++] = voffset ? voffset : ctx->i32_0;
args[idx++] = soffset ? soffset : ctx->i32_0;
args[idx++] = LLVMConstInt(ctx->i32, get_cache_flags(ctx, access, ac_access_type_load), 0);
unsigned func =
unsigned return_channels =
!ac_has_vec3_support(ctx->gfx_level, use_format) && num_channels == 3 ? 4 : num_channels;
const char *indexing_kind = vindex ? "struct" : "raw";
char name[256], type_name[8];
char name[256], type_name[64];
/* D16 is only supported on gfx8+ */
assert(!use_format || (channel_type != ctx->f16 && channel_type != ctx->i16) ||
ctx->gfx_level >= GFX8);
LLVMTypeRef type = func > 1 ? LLVMVectorType(channel_type, func) : channel_type;
ac_build_type_name_for_intr(type, type_name, sizeof(type_name));
LLVMTypeRef return_type = return_channels > 1 ? LLVMVectorType(channel_type, return_channels) :
channel_type;
if (tfe) {
assert(LLVM_VERSION_MAJOR >= 19); /* not supported by older LLVM */
assert(ac_get_type_size(channel_type) == 2 || ac_get_type_size(channel_type) == 4);
/* If the return type is a structure, the returned data is first, and TFE is second as i32. */
return_type = LLVMStructTypeInContext(ctx->context, (LLVMTypeRef[]){return_type, ctx->i32},
2, false);
}
ac_build_type_name_for_intr(return_type, type_name, sizeof(type_name));
if (use_format) {
snprintf(name, sizeof(name), "llvm.amdgcn.%s.buffer.load.format.%s", indexing_kind,
@ -956,10 +966,32 @@ static LLVMValueRef ac_build_buffer_load_common(struct ac_llvm_context *ctx, LLV
snprintf(name, sizeof(name), "llvm.amdgcn.%s.buffer.load.%s", indexing_kind, type_name);
}
LLVMValueRef result = ac_build_intrinsic(ctx, name, type, args, idx,
LLVMValueRef result = ac_build_intrinsic(ctx, name, return_type, args, idx,
can_speculate ? AC_ATTR_INVARIANT_LOAD : 0);
if (func > num_channels)
result = ac_trim_vector(ctx, result, num_channels);
if (tfe) {
/* The intrinsic returned a {data, i32} structure: member 0 is the loaded data and
 * member 1 is the TFE value.
 */
LLVMValueRef tfe_value = LLVMBuildExtractValue(ctx->builder, result, 1, "");
/* ac_get_type_size() is in bytes (the TFE path asserts it is 2 or 4), so a 16-bit
 * channel type has size 2. Truncate the 32-bit TFE value to i16 in that case so
 * that the bitcast below is between types of equal size.
 */
if (ac_get_type_size(channel_type) == 2)
tfe_value = LLVMBuildTrunc(ctx->builder, tfe_value, ctx->i16, "");
tfe_value = LLVMBuildBitCast(ctx->builder, tfe_value, channel_type, "");
result = LLVMBuildExtractValue(ctx->builder, result, 0, "");
/* Drop the extra channel that was requested when a 3-channel load had to be
 * widened to 4 channels.
 */
if (return_channels > num_channels)
result = ac_trim_vector(ctx, result, num_channels);
/* Append the TFE value as one additional channel after the data channels. */
result = ac_build_expand(ctx, result, num_channels, num_channels + 1);
result = LLVMBuildInsertElement(ctx->builder, result, tfe_value,
LLVMConstInt(ctx->i32, num_channels, 0), "");
} else {
if (return_channels > num_channels)
result = ac_trim_vector(ctx, result, num_channels);
}
return result;
}
@ -1011,7 +1043,7 @@ LLVMValueRef ac_build_buffer_load(struct ac_llvm_context *ctx, LLVMValueRef rsrc
LLVMConstInt(ctx->i32, i * ac_get_type_size(channel_type), 0), "");
LLVMValueRef item =
ac_build_buffer_load_common(ctx, rsrc, vindex, fetch_voffset, soffset, fetch_num_channels,
channel_type, access, can_speculate, false);
channel_type, access, can_speculate, false, false);
result = ac_build_concat(ctx, result, item);
}
@ -1023,81 +1055,9 @@ LLVMValueRef ac_build_buffer_load_format(struct ac_llvm_context *ctx, LLVMValueR
unsigned num_channels, enum gl_access_qualifier access,
bool can_speculate, bool d16, bool tfe)
{
if (tfe) {
assert(!d16);
union ac_hw_cache_flags cache_flags =
ac_get_hw_cache_flags(ctx->gfx_level, access, ac_access_type_load);
char code[1024];
/* The definition in the assembly and the one in the constraint string
* differs because of an assembler bug.
*/
if (ctx->gfx_level >= GFX12) {
const char *scope = "";
const char *temporal_hint = "";
if (cache_flags.gfx12.scope == gfx12_scope_se)
scope = "scope:SCOPE_SE";
else if (cache_flags.gfx12.scope == gfx12_scope_device)
scope = "scope:SCOPE_DEV";
else if (cache_flags.gfx12.scope == gfx12_scope_memory)
scope = "scope:SCOPE_SYS";
if (cache_flags.gfx12.temporal_hint == gfx12_load_non_temporal)
temporal_hint = "th:TH_LOAD_NT";
else if (cache_flags.gfx12.temporal_hint == gfx12_load_high_temporal)
temporal_hint = "th:TH_LOAD_HT";
else if (cache_flags.gfx12.temporal_hint == gfx12_load_last_use_discard)
temporal_hint = "th:TH_LOAD_LU";
else if (cache_flags.gfx12.temporal_hint == gfx12_load_near_non_temporal_far_regular_temporal)
temporal_hint = "th:TH_LOAD_NT_RT";
else if (cache_flags.gfx12.temporal_hint == gfx12_load_near_regular_temporal_far_non_temporal)
temporal_hint = "th:TH_LOAD_RT_NT";
else if (cache_flags.gfx12.temporal_hint == gfx12_load_near_non_temporal_far_high_temporal)
temporal_hint = "th:TH_LOAD_NT_HT";
snprintf(code, sizeof(code),
"v_mov_b32 v0, 0\n"
"v_mov_b32 v1, 0\n"
"v_mov_b32 v2, 0\n"
"v_mov_b32 v3, 0\n"
"v_mov_b32 v4, 0\n"
"buffer_load_format_xyzw v[0:3], $1, $2, 0, idxen offen %s %s tfe\n"
"s_waitcnt vmcnt(0)",
temporal_hint, scope);
} else {
snprintf(code, sizeof(code),
"v_mov_b32 v0, 0\n"
"v_mov_b32 v1, 0\n"
"v_mov_b32 v2, 0\n"
"v_mov_b32 v3, 0\n"
"v_mov_b32 v4, 0\n"
"buffer_load_format_xyzw v[0:3], $1, $2, 0, idxen offen %s %s tfe %s\n"
"s_waitcnt vmcnt(0)",
cache_flags.value & ac_glc ? "glc" : "",
cache_flags.value & ac_slc ? "slc" : "",
cache_flags.value & ac_dlc ? "dlc" : "");
}
LLVMTypeRef param_types[] = {ctx->v2i32, ctx->v4i32};
LLVMTypeRef calltype = LLVMFunctionType(LLVMVectorType(ctx->f32, 5), param_types, 2, false);
LLVMValueRef inlineasm = LLVMConstInlineAsm(calltype, code, "=&{v[0:4]},v,s", false, false);
LLVMValueRef addr_comp[2] = {vindex ? vindex : ctx->i32_0,
voffset ? voffset : ctx->i32_0};
LLVMValueRef args[] = {ac_build_gather_values(ctx, addr_comp, 2),
LLVMBuildBitCast(ctx->builder, rsrc, ctx->v4i32, "")};
LLVMValueRef res = LLVMBuildCall2(ctx->builder, calltype, inlineasm, args, 2, "");
return ac_build_concat(ctx, ac_trim_vector(ctx, res, num_channels),
ac_llvm_extract_elem(ctx, res, 4));
}
return ac_build_buffer_load_common(ctx, rsrc, vindex, voffset, ctx->i32_0,
num_channels, d16 ? ctx->f16 : ctx->f32, access,
can_speculate, true);
can_speculate, true, tfe);
}
static LLVMValueRef ac_build_tbuffer_load(struct ac_llvm_context *ctx, LLVMValueRef rsrc,
@ -1207,7 +1167,7 @@ LLVMValueRef ac_build_buffer_load_short(struct ac_llvm_context *ctx, LLVMValueRe
enum gl_access_qualifier access)
{
return ac_build_buffer_load_common(ctx, rsrc, NULL, voffset, soffset, 1, ctx->i16,
access, false, false);
access, false, false, false);
}
LLVMValueRef ac_build_buffer_load_byte(struct ac_llvm_context *ctx, LLVMValueRef rsrc,
@ -1215,7 +1175,7 @@ LLVMValueRef ac_build_buffer_load_byte(struct ac_llvm_context *ctx, LLVMValueRef
enum gl_access_qualifier access)
{
return ac_build_buffer_load_common(ctx, rsrc, NULL, voffset, soffset, 1, ctx->i8, access,
false, false);
false, false, false);
}
void ac_build_buffer_store_short(struct ac_llvm_context *ctx, LLVMValueRef rsrc,