brw: Add lowering for nir_cmat_call_op_per_element_op

Reviewed-by: Ian Romanick <ian.d.romanick@intel.com>
Reviewed-by: Dave Airlie <airlied@redhat.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/39904>
This commit is contained in:
Caio Oliveira 2026-02-13 20:56:57 -08:00 committed by Marge Bot
parent 095c470d25
commit ffc3219d57

View file

@ -129,6 +129,9 @@ lower_cmat_filter(const nir_instr *instr, const void *_state)
return glsl_type_is_cmat(deref->type);
}
if (instr->type == nir_instr_type_cmat_call)
return true;
if (instr->type != nir_instr_type_intrinsic)
return false;
@ -574,6 +577,150 @@ lower_cmat_scalar_op(nir_builder *b, nir_intrinsic_instr *intrin,
nir_component_mask(num_components));
}
/**
 * Lower a nir_cmat_call_op_per_element_op call into a loop that calls the
 * callee once for every matrix element held in this invocation's slice.
 *
 * Call parameters, as consumed here:
 *   params[0]   deref of the destination slice
 *   params[1]   row offset, added to every element's row coordinate
 *   params[2]   column offset, added to every element's column coordinate
 *   params[3]   deref of the source slice
 *   params[4..] extra operands.  An operand that is a deref of a variable
 *               registered in state->slice_var_to_slice_info is loaded and
 *               the callee receives the element matching the one being
 *               processed; any other operand is forwarded unchanged.
 *
 * Each emitted call gets a deref of a scalar temporary as params[0]; the
 * value read back from that temporary after the call is repacked and written
 * into a result slice, which is stored to the destination after the loop.
 */
static void
lower_cmat_per_element_op(nir_builder *b, nir_cmat_call_instr *call,
                          struct lower_cmat_state *state)
{
   /* The four fixed parameters are required; without this the unsigned
    * subtraction computing extra_params_count below would wrap around.
    */
   assert(call->num_params >= 4);

   nir_deref_instr *dst_slice = nir_src_as_deref(call->params[0]);
   nir_deref_instr *src_slice = nir_src_as_deref(call->params[3]);
   const slice_info *dst_info = get_slice_info(state, dst_slice);
   const slice_info *src_info = get_slice_info(state, src_slice);
   assert(cmat_descriptions_are_equal(dst_info->desc, src_info->desc));

   nir_def *row_offset = call->params[1].ssa;
   nir_def *col_offset = call->params[2].ssa;

   struct extra_param {
      /* Value forwarded to the callee, or the loaded slice when slice_info
       * is set.
       */
      nir_def *def;
      /* Non-NULL only when the operand is a known cooperative matrix slice. */
      const slice_info *slice_info;
   };

   const unsigned extra_params_count = call->num_params - 4;
   struct extra_param *extra_params =
      rzalloc_array(state->temp_ctx, struct extra_param, extra_params_count);

   for (unsigned p = 0; p < extra_params_count; p++) {
      const nir_src param_src = call->params[4 + p];
      struct extra_param *extra = &extra_params[p];

      extra->def = param_src.ssa;

      nir_deref_instr *deref = nir_src_as_deref(param_src);
      if (!deref)
         continue;

      /* nir_deref_instr_get_variable() yields NULL when the deref chain
       * does not lead back to a variable.  Such an operand cannot be a
       * tracked slice, and NULL must not be used as a hash table key, so
       * forward it unchanged.
       */
      nir_variable *var = nir_deref_instr_get_variable(deref);
      if (!var)
         continue;

      struct hash_entry *entry =
         _mesa_hash_table_search(state->slice_var_to_slice_info, var);
      if (entry) {
         extra->slice_info = entry->data;
         extra->def = nir_load_deref(b, deref);
         assert(cmat_descriptions_are_equal(src_info->desc,
                                            extra->slice_info->desc));
      }
   }

   const struct glsl_cmat_description desc = dst_info->desc;
   const unsigned bits = glsl_base_type_bit_size(desc.element_type);
   const unsigned num_components = glsl_get_vector_elements(dst_slice->type);
   const unsigned packing_factor = dst_info->packing_factor;

   /* vals[] below is sized by BRW_MAX_PACKING_FACTOR and filled for every
    * j < packing_factor.
    */
   assert(packing_factor <= BRW_MAX_PACKING_FACTOR);

   nir_def *src = nir_load_deref(b, src_slice);

   nir_def *invocation = nir_load_subgroup_invocation(b);

   /* For uses other than B the packed elements sit along the column
    * direction, so the number of column positions covered by invocations is
    * divided by the packing factor.
    */
   const unsigned cols =
      desc.use != GLSL_CMAT_USE_B ? desc.cols / packing_factor : desc.cols;
   const unsigned row_step = state->subgroup_size / cols;

   nir_def *invocation_div_cols = nir_udiv_imm(b, invocation, cols);
   nir_def *invocation_mod_cols = nir_umod_imm(b, invocation, cols);

   /* Scalar temporary passed as params[0] of every emitted call. */
   nir_deref_instr *elem_deref =
      nir_build_deref_var(b, nir_local_variable_create(b->impl,
                                                       glsl_scalar_type(desc.element_type),
                                                       "cmat_per_element_tmp"));

   /* Accumulates the output slice one component per loop iteration. */
   nir_deref_instr *result_deref =
      nir_build_deref_var(b, nir_local_variable_create(b->impl,
                                                       dst_slice->type,
                                                       "cmat_per_element_result"));
   nir_store_deref(b, result_deref,
                   nir_undef(b, num_components,
                             glsl_get_bit_size(dst_slice->type)),
                   nir_component_mask(num_components));

   /* Loop counter over the components of the slice. */
   nir_deref_instr *iter_deref =
      nir_build_deref_var(b, nir_local_variable_create(b->impl,
                                                       glsl_uint_type(),
                                                       "cmat_per_element_iter"));
   nir_store_deref(b, iter_deref, nir_imm_int(b, 0), 0x1);

   nir_loop *loop = nir_push_loop(b);
   {
      nir_def *iter = nir_load_deref(b, iter_deref);
      nir_break_if(b, nir_uge_imm(b, iter, num_components));

      /* Each slice component advances the row group by row_step. */
      nir_def *row_group =
         nir_iadd(b, invocation_div_cols, nir_imul_imm(b, iter, row_step));

      nir_def *row_base = row_group;
      nir_def *col_base = invocation_mod_cols;

      /* Scale the coordinate along which elements are packed: the row for
       * use B, the column otherwise.
       */
      if (desc.use == GLSL_CMAT_USE_B)
         row_base = nir_imul_imm(b, row_base, packing_factor);
      else
         col_base = nir_imul_imm(b, col_base, packing_factor);

      nir_def *packed = nir_vector_extract(b, src, iter);
      nir_def *unpacked =
         packing_factor > 1 ? nir_unpack_bits(b, packed, bits) : packed;

      nir_def *vals[BRW_MAX_PACKING_FACTOR] = {0};

      for (unsigned j = 0; j < packing_factor; j++) {
         nir_def *row_val = nir_iadd(b, row_base, row_offset);
         nir_def *col_val = nir_iadd(b, col_base, col_offset);

         /* Step to the j-th element packed into this component. */
         if (desc.use == GLSL_CMAT_USE_B)
            row_val = nir_iadd_imm(b, row_val, j);
         else
            col_val = nir_iadd_imm(b, col_val, j);

         nir_def *src_elem =
            packing_factor > 1 ? nir_channel(b, unpacked, j) : unpacked;

         nir_call_instr *new_call =
            nir_call_instr_create(b->shader, call->callee);
         new_call->params[0] = nir_src_for_ssa(&elem_deref->def);
         new_call->params[1] = nir_src_for_ssa(row_val);
         new_call->params[2] = nir_src_for_ssa(col_val);
         new_call->params[3] = nir_src_for_ssa(src_elem);

         for (unsigned p = 0; p < extra_params_count; p++) {
            nir_def *def = extra_params[p].def;

            /* Any additional cooperative matrix operands have the corresponding
             * matrix element passed to the function.
             */
            const slice_info *info = extra_params[p].slice_info;
            if (info) {
               def = nir_vector_extract(b, def, iter);
               if (info->packing_factor > 1) {
                  const unsigned param_bit_size =
                     glsl_base_type_bit_size(info->desc.element_type);
                  def = nir_channel(b, nir_unpack_bits(b, def, param_bit_size), j);
               }
            }

            new_call->params[4 + p] = nir_src_for_ssa(def);
         }

         nir_builder_instr_insert(b, &new_call->instr);

         /* Read back the value stored through params[0] of the call. */
         vals[j] = nir_load_deref(b, elem_deref);
      }

      /* Repack the per-element results into one slice component. */
      packed =
         packing_factor > 1 ? nir_pack_bits(b, nir_vec(b, vals, packing_factor),
                                            packing_factor * bits)
                            : vals[0];

      nir_def *new_vec = nir_vector_insert(b, nir_load_deref(b, result_deref),
                                           packed, iter);
      nir_store_deref(b, result_deref, new_vec,
                      nir_component_mask(num_components));

      nir_store_deref(b, iter_deref, nir_iadd_imm(b, iter, 1), 0x1);
   }
   nir_pop_loop(b, loop);

   nir_store_deref(b, dst_slice, nir_load_deref(b, result_deref),
                   nir_component_mask(num_components));
}
static nir_deref_instr *
lower_cmat_deref(nir_builder *b, nir_deref_instr *deref,
struct lower_cmat_state *state)
@ -605,6 +752,17 @@ lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state)
return &deref->def;
}
if (instr->type == nir_instr_type_cmat_call) {
nir_cmat_call_instr *call = nir_instr_as_cmat_call(instr);
switch (call->op) {
case nir_cmat_call_op_per_element_op:
lower_cmat_per_element_op(b, call, state);
return NIR_LOWER_INSTR_PROGRESS_REPLACE;
default:
UNREACHABLE("invalid cooperative matrix call");
}
}
nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
switch (intrin->intrinsic) {
case nir_intrinsic_cmat_load: