nir/coopmat: refactor the split vars to clean it up

This just moves to using the split_info to store all the info,
and updating the hash table inside the split function.

Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/41500>
This commit is contained in:
David Airlie 2026-05-12 12:19:25 +10:00 committed by Marge Bot
parent 510998e493
commit eaf6207e06

View file

@ -853,17 +853,14 @@ split_matrix_impl(nir_function_impl *impl, struct split_info *info)
return progress;
}
static struct split_mat *
static void
split_var(nir_shader *shader,
nir_function_impl *impl,
void *mem_ctx,
nir_variable *var,
unsigned m_gran,
unsigned n_gran,
unsigned k_gran)
struct split_info *info,
nir_variable *var)
{
if (!glsl_type_is_cmat(glsl_without_array(var->type)))
return NULL;
return;
const struct glsl_type *type = var->type;
if (glsl_type_is_array(type)) {
@ -873,7 +870,7 @@ split_var(nir_shader *shader,
struct glsl_cmat_description desc = *glsl_get_cmat_description(type);
unsigned split_rows = 0, split_cols = 0;
get_lower_sizes(desc, m_gran, n_gran, k_gran, &split_rows, &split_cols);
get_lower_sizes(desc, info->m_gran, info->n_gran, info->k_gran, &split_rows, &split_cols);
unsigned num_row_split = 1, num_col_split = 1;
@ -887,13 +884,13 @@ split_var(nir_shader *shader,
}
if (num_row_split == 1 && num_col_split == 1)
return NULL;
return;
const struct glsl_type *new_type = glsl_type_wrap_in_arrays(glsl_cmat_type(&desc), var->type);
struct split_mat *split_mat = ralloc(mem_ctx, struct split_mat);
struct split_mat *split_mat = ralloc(info->split_mats, struct split_mat);
if (!split_mat)
return NULL;
return;
unsigned num_split = num_row_split * num_col_split;
split_mat->num_row_splits = num_row_split;
@ -908,27 +905,18 @@ split_var(nir_shader *shader,
new_type, var->name);
}
}
return split_mat;
_mesa_hash_table_insert(info->split_mats, var, split_mat);
}
static bool
lower_dimensions(nir_shader *shader, nir_function_impl *impl,
unsigned m_gran, unsigned n_gran, unsigned k_gran)
{
struct hash_table *split_mats = _mesa_pointer_hash_table_create(NULL);
void *mem_ctx = ralloc_context(NULL);
bool progress = false;
struct hash_table *split_mats = _mesa_pointer_hash_table_create(mem_ctx);
nir_foreach_variable_in_shader(var, shader) {
struct split_mat *split_mat = split_var(shader, NULL, mem_ctx, var, m_gran, n_gran, k_gran);
if (split_mat)
_mesa_hash_table_insert(split_mats, var, split_mat);
}
nir_foreach_function_temp_variable (var, impl) {
struct split_mat *split_mat = split_var(shader, impl, mem_ctx, var, m_gran, n_gran, k_gran);
if (split_mat)
_mesa_hash_table_insert(split_mats, var, split_mat);
}
bool progress = false;
struct split_info split_info = {
.split_mats = split_mats,
@ -936,6 +924,14 @@ lower_dimensions(nir_shader *shader, nir_function_impl *impl,
.n_gran = n_gran,
.k_gran = k_gran,
};
nir_foreach_variable_in_shader(var, shader) {
split_var(shader, NULL, &split_info, var);
}
nir_foreach_function_temp_variable (var, impl) {
split_var(shader, impl, &split_info, var);
}
progress = split_matrix_impl(impl, &split_info);
_mesa_hash_table_destroy(split_mats, NULL);
ralloc_free(mem_ctx);