mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-03-19 23:00:45 +01:00
radv: add support for cooperative matrix per element operations.
Reviewed-by: Georg Lehmann <dadschoorse@gmail.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/36992>
This commit is contained in:
parent
1ba49c3594
commit
3eef0c0245
2 changed files with 76 additions and 0 deletions
|
|
@ -983,6 +983,78 @@ lower_cmat_deref(nir_deref_instr *deref, struct hash_table *type_map, const lowe
|
|||
return false;
|
||||
}
|
||||
|
||||
static bool
|
||||
lower_cmat_per_element_op(nir_builder *b, nir_cmat_call_instr *call, const lower_cmat_params *params)
|
||||
{
|
||||
nir_def *src = radv_nir_load_cmat(b, params, call->params[3].ssa);
|
||||
nir_deref_instr *dst_deref = nir_src_as_deref(call->params[0]);
|
||||
nir_function *fnptr = call->callee;
|
||||
|
||||
struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
|
||||
|
||||
nir_variable *elem_tmp = nir_local_variable_create(b->impl, glsl_get_cmat_element(dst_deref->type), "elemtmp");
|
||||
nir_deref_instr *elem_deref = nir_build_deref_var(b, elem_tmp);
|
||||
nir_def *local_idx = nir_load_subgroup_invocation(b);
|
||||
nir_def *inner_idx = nir_iand_imm(b, local_idx, 15);
|
||||
unsigned length = radv_nir_cmat_length(desc, params);
|
||||
unsigned mul = radv_nir_cmat_length_mul(desc, params);
|
||||
unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? params->wave_size : 16;
|
||||
nir_def *base_row = radv_get_base_row(b, desc, params, local_idx);
|
||||
nir_def *vars[16];
|
||||
|
||||
if (mul > 1) {
|
||||
for (unsigned i = 0; i < length; ++i)
|
||||
if (i % mul != 0)
|
||||
vars[i] = nir_undef(b, 1, radv_nir_cmat_bits(desc));
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < length / mul; i++) {
|
||||
nir_def *src_elem = nir_channel(b, src, i * mul);
|
||||
nir_call_instr *new_call = nir_call_instr_create(b->shader, fnptr);
|
||||
uint32_t row_iter;
|
||||
|
||||
if (params->gfx_level >= GFX12) {
|
||||
row_iter = i;
|
||||
} else {
|
||||
row_iter = i * lanes_per_iter / 16;
|
||||
}
|
||||
|
||||
nir_def *row_val = nir_iadd_imm(b, base_row, row_iter);
|
||||
nir_def *col_val = inner_idx;
|
||||
|
||||
if (desc.use == GLSL_CMAT_USE_A)
|
||||
SWAP(col_val, row_val);
|
||||
|
||||
row_val = nir_iadd(b, call->params[1].ssa, row_val);
|
||||
col_val = nir_iadd(b, call->params[2].ssa, col_val);
|
||||
|
||||
new_call->params[0] = nir_src_for_ssa(&elem_deref->def);
|
||||
new_call->params[1] = nir_src_for_ssa(row_val);
|
||||
new_call->params[2] = nir_src_for_ssa(col_val);
|
||||
new_call->params[3] = nir_src_for_ssa(src_elem);
|
||||
|
||||
for (unsigned p = 4; p < call->num_params; p++) {
|
||||
nir_deref_instr *deref = nir_src_as_deref(call->params[p]);
|
||||
nir_def *def = call->params[p].ssa;
|
||||
if (deref) {
|
||||
if (glsl_type_is_cmat(deref->type)) {
|
||||
def = nir_build_load_deref(b, radv_nir_cmat_length(desc, params), radv_nir_cmat_bits(desc), def);
|
||||
def = nir_channel(b, def, i * mul);
|
||||
}
|
||||
}
|
||||
new_call->params[p] = nir_src_for_ssa(def);
|
||||
}
|
||||
nir_builder_instr_insert(b, &new_call->instr);
|
||||
vars[i * mul] = nir_build_load_deref(b, 1, radv_nir_cmat_bits(desc), &elem_deref->def, 0);
|
||||
}
|
||||
|
||||
nir_def *mat = nir_vec(b, vars, length);
|
||||
nir_store_deref(b, dst_deref, mat, nir_component_mask(src->num_components));
|
||||
|
||||
nir_instr_remove(&call->instr);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool
|
||||
radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_level, unsigned wave_size)
|
||||
{
|
||||
|
|
@ -1084,6 +1156,9 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
|
|||
case nir_cmat_call_op_reduce_2x2:
|
||||
progress |= lower_cmat_reduce_2x2_call(&b, call, &params);
|
||||
break;
|
||||
case nir_cmat_call_op_per_element_op:
|
||||
progress |= lower_cmat_per_element_op(&b, call, &params);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1403,6 +1403,7 @@ radv_physical_device_get_features(const struct radv_physical_device *pdev, struc
|
|||
.cooperativeMatrixConversions = true,
|
||||
.cooperativeMatrixFlexibleDimensions = true,
|
||||
.cooperativeMatrixReductions = true,
|
||||
.cooperativeMatrixPerElementOperations = true,
|
||||
|
||||
/* VK_KHR_video_encode_av1 */
|
||||
.videoEncodeAV1 = true,
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue