nak: extract cmat load/store element offset calculation

Reviewed-by: Faith Ekstrand <faith.ekstrand@collabora.com>
Fixes: 05dca16143 ("nak: extract nir_intrinsic_cmat_load lowering into a function")
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/37941>
This commit is contained in:
Karol Herbst 2025-10-20 14:58:35 +02:00 committed by Marge Bot
parent d423554e9e
commit f632bfc715

View file

@ -671,6 +671,38 @@ try_lower_cmat_load_to_ldsm(nir_builder *b, nir_intrinsic_instr *intr)
.matrix_layout = layout);
}
static nir_deref_instr*
get_cmat_component_deref(nir_builder *b, nir_intrinsic_instr *intr,
nir_def *lane_id, unsigned idx)
{
unsigned deref_src = intr->intrinsic == nir_intrinsic_cmat_store ? 0 : 1;
unsigned cmat_src = intr->intrinsic == nir_intrinsic_cmat_store ? 1 : 0;
const struct glsl_cmat_description desc = cmat_src_desc(intr->src[cmat_src]);
nir_deref_instr *deref = nir_def_as_deref(intr->src[deref_src].ssa);
const enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);
nir_def *stride = intr->src[2].ssa;
nir_def *col_offset;
nir_def *row_offset;
compute_matrix_offsets(b, desc, layout, lane_id, idx,
&col_offset, &row_offset);
row_offset = nir_imul(b, row_offset, stride);
col_offset = nir_u2uN(b, col_offset, deref->def.bit_size);
row_offset = nir_u2uN(b, row_offset, deref->def.bit_size);
nir_deref_instr *iter_deref =
nir_build_deref_ptr_as_array(b, deref, row_offset);
iter_deref = nir_build_deref_cast(
b, &iter_deref->def, deref->modes,
glsl_scalar_type(desc.element_type),
glsl_base_type_bit_size(desc.element_type) / 8);
return nir_build_deref_ptr_as_array(b, iter_deref, col_offset);
}
static void
lower_cmat_load(nir_builder *b, nir_intrinsic_instr *intr)
{
@ -682,10 +714,6 @@ lower_cmat_load(nir_builder *b, nir_intrinsic_instr *intr)
const struct glsl_cmat_description desc = cmat_src_desc(intr->src[0]);
const unsigned length = get_cmat_length(desc);
const enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);
nir_deref_instr *deref = nir_def_as_deref(intr->src[1].ssa);
nir_def *stride = intr->src[2].ssa;
nir_def *vars[NIR_MAX_VEC_COMPONENTS];
for (unsigned i = 0; i < length; ++i)
@ -694,26 +722,8 @@ lower_cmat_load(nir_builder *b, nir_intrinsic_instr *intr)
nir_def *lane_id = nir_load_subgroup_invocation(b);
for (unsigned idx = 0; idx < length; idx++) {
nir_def *col_offset;
nir_def *row_offset;
compute_matrix_offsets(b, desc, layout, lane_id, idx,
&col_offset, &row_offset);
row_offset = nir_imul(b, row_offset, stride);
col_offset = nir_u2uN(b, col_offset, deref->def.bit_size);
row_offset = nir_u2uN(b, row_offset, deref->def.bit_size);
nir_deref_instr *iter_deref =
nir_build_deref_ptr_as_array(b, deref, row_offset);
iter_deref = nir_build_deref_cast(
b, &iter_deref->def, deref->modes,
glsl_scalar_type(desc.element_type),
glsl_base_type_bit_size(desc.element_type) / 8);
iter_deref =
nir_build_deref_ptr_as_array(b, iter_deref, col_offset);
get_cmat_component_deref(b, intr, lane_id, idx);
vars[idx] = nir_load_deref(b, iter_deref);
}
@ -764,11 +774,6 @@ lower_cmat_instr(nir_builder *b,
}
case nir_intrinsic_cmat_store: {
enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);
nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
nir_def *stride = intr->src[2].ssa;
const struct glsl_cmat_description desc = cmat_src_desc(intr->src[1]);
const unsigned length = get_cmat_length(desc);
nir_def *src = load_cmat_src(b, intr->src[1]);
@ -780,26 +785,8 @@ lower_cmat_instr(nir_builder *b,
nir_def *lane_id = nir_load_subgroup_invocation(b);
for (unsigned idx = 0; idx < length; idx++) {
nir_def *col_offset;
nir_def *row_offset;
compute_matrix_offsets(b, desc, layout, lane_id, idx,
&col_offset, &row_offset);
row_offset = nir_imul(b, row_offset, stride);
col_offset = nir_u2uN(b, col_offset, deref->def.bit_size);
row_offset = nir_u2uN(b, row_offset, deref->def.bit_size);
nir_deref_instr *iter_deref =
nir_build_deref_ptr_as_array(b, deref, row_offset);
iter_deref = nir_build_deref_cast(
b, &iter_deref->def, deref->modes,
glsl_scalar_type(desc.element_type),
glsl_base_type_bit_size(desc.element_type) / 8);
iter_deref =
nir_build_deref_ptr_as_array(b, iter_deref, col_offset);
get_cmat_component_deref(b, intr, lane_id, idx);
nir_store_deref(b, iter_deref, vars[idx], 1);
}