mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-04-25 04:50:38 +02:00
brw: Add lowering for nir_cmat_call_op_per_element_op
Reviewed-by: Ian Romanick <ian.d.romanick@intel.com>
Reviewed-by: Dave Airlie <airlied@redhat.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/39904>
This commit is contained in:
parent
095c470d25
commit
ffc3219d57
1 changed file with 158 additions and 0 deletions
|
|
@ -129,6 +129,9 @@ lower_cmat_filter(const nir_instr *instr, const void *_state)
|
|||
return glsl_type_is_cmat(deref->type);
|
||||
}
|
||||
|
||||
if (instr->type == nir_instr_type_cmat_call)
|
||||
return true;
|
||||
|
||||
if (instr->type != nir_instr_type_intrinsic)
|
||||
return false;
|
||||
|
||||
|
|
@ -574,6 +577,150 @@ lower_cmat_scalar_op(nir_builder *b, nir_intrinsic_instr *intrin,
|
|||
nir_component_mask(num_components));
|
||||
}
|
||||
|
||||
static void
|
||||
lower_cmat_per_element_op(nir_builder *b, nir_cmat_call_instr *call,
|
||||
struct lower_cmat_state *state)
|
||||
{
|
||||
nir_deref_instr *dst_slice = nir_src_as_deref(call->params[0]);
|
||||
nir_deref_instr *src_slice = nir_src_as_deref(call->params[3]);
|
||||
const slice_info *dst_info = get_slice_info(state, dst_slice);
|
||||
const slice_info *src_info = get_slice_info(state, src_slice);
|
||||
assert(cmat_descriptions_are_equal(dst_info->desc, src_info->desc));
|
||||
|
||||
nir_def *row_offset = call->params[1].ssa;
|
||||
nir_def *col_offset = call->params[2].ssa;
|
||||
|
||||
struct extra_param {
|
||||
nir_def *def;
|
||||
const slice_info *slice_info;
|
||||
};
|
||||
|
||||
const unsigned extra_params_count = call->num_params - 4;
|
||||
struct extra_param *extra_params =
|
||||
rzalloc_array(state->temp_ctx, struct extra_param, extra_params_count);
|
||||
|
||||
for (unsigned p = 0; p < extra_params_count; p++) {
|
||||
const nir_src param_src = call->params[4 + p];
|
||||
struct extra_param *extra = &extra_params[p];
|
||||
extra->def = param_src.ssa;
|
||||
|
||||
nir_deref_instr *deref = nir_src_as_deref(param_src);
|
||||
if (deref) {
|
||||
nir_variable *var = nir_deref_instr_get_variable(deref);
|
||||
struct hash_entry *entry =
|
||||
_mesa_hash_table_search(state->slice_var_to_slice_info, var);
|
||||
if (entry) {
|
||||
extra->slice_info = entry->data;
|
||||
extra->def = nir_load_deref(b, deref);
|
||||
assert(cmat_descriptions_are_equal(src_info->desc, extra->slice_info->desc));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const struct glsl_cmat_description desc = dst_info->desc;
|
||||
const unsigned bits = glsl_base_type_bit_size(desc.element_type);
|
||||
const unsigned num_components = glsl_get_vector_elements(dst_slice->type);
|
||||
const unsigned packing_factor = dst_info->packing_factor;
|
||||
nir_def *src = nir_load_deref(b, src_slice);
|
||||
|
||||
nir_def *invocation = nir_load_subgroup_invocation(b);
|
||||
const unsigned cols = desc.use != GLSL_CMAT_USE_B ? desc.cols / packing_factor : desc.cols;
|
||||
const unsigned row_step = state->subgroup_size / cols;
|
||||
nir_def *invocation_div_cols = nir_udiv_imm(b, invocation, cols);
|
||||
nir_def *invocation_mod_cols = nir_umod_imm(b, invocation, cols);
|
||||
|
||||
nir_deref_instr *elem_deref =
|
||||
nir_build_deref_var(b, nir_local_variable_create(b->impl, glsl_scalar_type(desc.element_type),
|
||||
"cmat_per_element_tmp"));
|
||||
|
||||
nir_deref_instr *result_deref =
|
||||
nir_build_deref_var(b, nir_local_variable_create(b->impl, dst_slice->type,
|
||||
"cmat_per_element_result"));
|
||||
nir_store_deref(b, result_deref,
|
||||
nir_undef(b, num_components, glsl_get_bit_size(dst_slice->type)),
|
||||
nir_component_mask(num_components));
|
||||
|
||||
nir_deref_instr *iter_deref =
|
||||
nir_build_deref_var(b, nir_local_variable_create(b->impl, glsl_uint_type(),
|
||||
"cmat_per_element_iter"));
|
||||
nir_store_deref(b, iter_deref, nir_imm_int(b, 0), 0x1);
|
||||
|
||||
nir_loop *loop = nir_push_loop(b);
|
||||
{
|
||||
nir_def *iter = nir_load_deref(b, iter_deref);
|
||||
nir_break_if(b, nir_uge_imm(b, iter, num_components));
|
||||
|
||||
nir_def *row_group =
|
||||
nir_iadd(b, invocation_div_cols, nir_imul_imm(b, iter, row_step));
|
||||
|
||||
nir_def *row_base = row_group;
|
||||
nir_def *col_base = invocation_mod_cols;
|
||||
|
||||
if (desc.use == GLSL_CMAT_USE_B)
|
||||
row_base = nir_imul_imm(b, row_base, packing_factor);
|
||||
else
|
||||
col_base = nir_imul_imm(b, col_base, packing_factor);
|
||||
|
||||
nir_def *packed = nir_vector_extract(b, src, iter);
|
||||
nir_def *unpacked = packing_factor > 1 ? nir_unpack_bits(b, packed, bits) : packed;
|
||||
nir_def *vals[BRW_MAX_PACKING_FACTOR] = {0};
|
||||
|
||||
for (unsigned j = 0; j < packing_factor; j++) {
|
||||
nir_def *row_val = nir_iadd(b, row_base, row_offset);
|
||||
nir_def *col_val = nir_iadd(b, col_base, col_offset);
|
||||
|
||||
if (desc.use == GLSL_CMAT_USE_B)
|
||||
row_val = nir_iadd_imm(b, row_val, j);
|
||||
else
|
||||
col_val = nir_iadd_imm(b, col_val, j);
|
||||
|
||||
nir_def *src_elem = packing_factor > 1 ? nir_channel(b, unpacked, j) : unpacked;
|
||||
|
||||
nir_call_instr *new_call = nir_call_instr_create(b->shader, call->callee);
|
||||
new_call->params[0] = nir_src_for_ssa(&elem_deref->def);
|
||||
new_call->params[1] = nir_src_for_ssa(row_val);
|
||||
new_call->params[2] = nir_src_for_ssa(col_val);
|
||||
new_call->params[3] = nir_src_for_ssa(src_elem);
|
||||
|
||||
for (unsigned p = 0; p < extra_params_count; p++) {
|
||||
nir_def *def = extra_params[p].def;
|
||||
|
||||
/* Any additional cooperative matrix operands have the corresponding
|
||||
* matrix element passed to the function.
|
||||
*/
|
||||
const slice_info *info = extra_params[p].slice_info;
|
||||
if (info) {
|
||||
def = nir_vector_extract(b, def, iter);
|
||||
if (info->packing_factor > 1) {
|
||||
const unsigned param_bit_size =
|
||||
glsl_base_type_bit_size(info->desc.element_type);
|
||||
def = nir_channel(b, nir_unpack_bits(b, def, param_bit_size), j);
|
||||
}
|
||||
}
|
||||
|
||||
new_call->params[4 + p] = nir_src_for_ssa(def);
|
||||
}
|
||||
|
||||
nir_builder_instr_insert(b, &new_call->instr);
|
||||
vals[j] = nir_load_deref(b, elem_deref);
|
||||
}
|
||||
|
||||
packed =
|
||||
packing_factor > 1 ? nir_pack_bits(b, nir_vec(b, vals, packing_factor),
|
||||
packing_factor * bits)
|
||||
: vals[0];
|
||||
nir_def *new_vec = nir_vector_insert(b, nir_load_deref(b, result_deref),
|
||||
packed, iter);
|
||||
nir_store_deref(b, result_deref, new_vec, nir_component_mask(num_components));
|
||||
|
||||
nir_store_deref(b, iter_deref, nir_iadd_imm(b, iter, 1), 0x1);
|
||||
}
|
||||
nir_pop_loop(b, loop);
|
||||
|
||||
nir_store_deref(b, dst_slice, nir_load_deref(b, result_deref),
|
||||
nir_component_mask(num_components));
|
||||
}
|
||||
|
||||
static nir_deref_instr *
|
||||
lower_cmat_deref(nir_builder *b, nir_deref_instr *deref,
|
||||
struct lower_cmat_state *state)
|
||||
|
|
@ -605,6 +752,17 @@ lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state)
|
|||
return &deref->def;
|
||||
}
|
||||
|
||||
if (instr->type == nir_instr_type_cmat_call) {
|
||||
nir_cmat_call_instr *call = nir_instr_as_cmat_call(instr);
|
||||
switch (call->op) {
|
||||
case nir_cmat_call_op_per_element_op:
|
||||
lower_cmat_per_element_op(b, call, state);
|
||||
return NIR_LOWER_INSTR_PROGRESS_REPLACE;
|
||||
default:
|
||||
UNREACHABLE("invalid cooperative matrix call");
|
||||
}
|
||||
}
|
||||
|
||||
nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
|
||||
switch (intrin->intrinsic) {
|
||||
case nir_intrinsic_cmat_load:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue