radv: add support for cooperative matrix per element operations.
Some checks are pending
macOS-CI / macOS-CI (dri) (push) Waiting to run
macOS-CI / macOS-CI (xlib) (push) Waiting to run

Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/36992>
This commit is contained in:
Dave Airlie 2025-08-18 11:32:40 +10:00
parent 1ba49c3594
commit 3eef0c0245
2 changed files with 76 additions and 0 deletions

View file

@ -983,6 +983,78 @@ lower_cmat_deref(nir_deref_instr *deref, struct hash_table *type_map, const lowe
return false;
}
static bool
lower_cmat_per_element_op(nir_builder *b, nir_cmat_call_instr *call, const lower_cmat_params *params)
{
nir_def *src = radv_nir_load_cmat(b, params, call->params[3].ssa);
nir_deref_instr *dst_deref = nir_src_as_deref(call->params[0]);
nir_function *fnptr = call->callee;
struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
nir_variable *elem_tmp = nir_local_variable_create(b->impl, glsl_get_cmat_element(dst_deref->type), "elemtmp");
nir_deref_instr *elem_deref = nir_build_deref_var(b, elem_tmp);
nir_def *local_idx = nir_load_subgroup_invocation(b);
nir_def *inner_idx = nir_iand_imm(b, local_idx, 15);
unsigned length = radv_nir_cmat_length(desc, params);
unsigned mul = radv_nir_cmat_length_mul(desc, params);
unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? params->wave_size : 16;
nir_def *base_row = radv_get_base_row(b, desc, params, local_idx);
nir_def *vars[16];
if (mul > 1) {
for (unsigned i = 0; i < length; ++i)
if (i % mul != 0)
vars[i] = nir_undef(b, 1, radv_nir_cmat_bits(desc));
}
for (unsigned i = 0; i < length / mul; i++) {
nir_def *src_elem = nir_channel(b, src, i * mul);
nir_call_instr *new_call = nir_call_instr_create(b->shader, fnptr);
uint32_t row_iter;
if (params->gfx_level >= GFX12) {
row_iter = i;
} else {
row_iter = i * lanes_per_iter / 16;
}
nir_def *row_val = nir_iadd_imm(b, base_row, row_iter);
nir_def *col_val = inner_idx;
if (desc.use == GLSL_CMAT_USE_A)
SWAP(col_val, row_val);
row_val = nir_iadd(b, call->params[1].ssa, row_val);
col_val = nir_iadd(b, call->params[2].ssa, col_val);
new_call->params[0] = nir_src_for_ssa(&elem_deref->def);
new_call->params[1] = nir_src_for_ssa(row_val);
new_call->params[2] = nir_src_for_ssa(col_val);
new_call->params[3] = nir_src_for_ssa(src_elem);
for (unsigned p = 4; p < call->num_params; p++) {
nir_deref_instr *deref = nir_src_as_deref(call->params[p]);
nir_def *def = call->params[p].ssa;
if (deref) {
if (glsl_type_is_cmat(deref->type)) {
def = nir_build_load_deref(b, radv_nir_cmat_length(desc, params), radv_nir_cmat_bits(desc), def);
def = nir_channel(b, def, i * mul);
}
}
new_call->params[p] = nir_src_for_ssa(def);
}
nir_builder_instr_insert(b, &new_call->instr);
vars[i * mul] = nir_build_load_deref(b, 1, radv_nir_cmat_bits(desc), &elem_deref->def, 0);
}
nir_def *mat = nir_vec(b, vars, length);
nir_store_deref(b, dst_deref, mat, nir_component_mask(src->num_components));
nir_instr_remove(&call->instr);
return true;
}
bool
radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_level, unsigned wave_size)
{
@ -1084,6 +1156,9 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
case nir_cmat_call_op_reduce_2x2:
progress |= lower_cmat_reduce_2x2_call(&b, call, &params);
break;
case nir_cmat_call_op_per_element_op:
progress |= lower_cmat_per_element_op(&b, call, &params);
break;
default:
break;
}

View file

@ -1403,6 +1403,7 @@ radv_physical_device_get_features(const struct radv_physical_device *pdev, struc
.cooperativeMatrixConversions = true,
.cooperativeMatrixFlexibleDimensions = true,
.cooperativeMatrixReductions = true,
.cooperativeMatrixPerElementOperations = true,
/* VK_KHR_video_encode_av1 */
.videoEncodeAV1 = true,