diff --git a/src/compiler/nir/nir_lower_cooperative_matrix.c b/src/compiler/nir/nir_lower_cooperative_matrix.c index 70f75b18bc1..20a56f92389 100644 --- a/src/compiler/nir/nir_lower_cooperative_matrix.c +++ b/src/compiler/nir/nir_lower_cooperative_matrix.c @@ -853,17 +853,14 @@ split_matrix_impl(nir_function_impl *impl, struct split_info *info) return progress; } -static struct split_mat * +static void split_var(nir_shader *shader, nir_function_impl *impl, - void *mem_ctx, - nir_variable *var, - unsigned m_gran, - unsigned n_gran, - unsigned k_gran) + struct split_info *info, + nir_variable *var) { if (!glsl_type_is_cmat(glsl_without_array(var->type))) - return NULL; + return; const struct glsl_type *type = var->type; if (glsl_type_is_array(type)) { @@ -873,7 +870,7 @@ split_var(nir_shader *shader, struct glsl_cmat_description desc = *glsl_get_cmat_description(type); unsigned split_rows = 0, split_cols = 0; - get_lower_sizes(desc, m_gran, n_gran, k_gran, &split_rows, &split_cols); + get_lower_sizes(desc, info->m_gran, info->n_gran, info->k_gran, &split_rows, &split_cols); unsigned num_row_split = 1, num_col_split = 1; @@ -887,13 +884,13 @@ split_var(nir_shader *shader, } if (num_row_split == 1 && num_col_split == 1) - return NULL; + return; const struct glsl_type *new_type = glsl_type_wrap_in_arrays(glsl_cmat_type(&desc), var->type); - struct split_mat *split_mat = ralloc(mem_ctx, struct split_mat); + struct split_mat *split_mat = ralloc(info->split_mats, struct split_mat); if (!split_mat) - return NULL; + return; unsigned num_split = num_row_split * num_col_split; split_mat->num_row_splits = num_row_split; @@ -908,27 +905,18 @@ split_var(nir_shader *shader, new_type, var->name); } } - return split_mat; + + _mesa_hash_table_insert(info->split_mats, var, split_mat); } static bool lower_dimensions(nir_shader *shader, nir_function_impl *impl, unsigned m_gran, unsigned n_gran, unsigned k_gran) { - struct hash_table *split_mats = _mesa_pointer_hash_table_create(NULL); void *mem_ctx = ralloc_context(NULL); - bool progress = false; + struct hash_table *split_mats = _mesa_pointer_hash_table_create(mem_ctx); - nir_foreach_variable_in_shader(var, shader) { - struct split_mat *split_mat = split_var(shader, NULL, mem_ctx, var, m_gran, n_gran, k_gran); - if (split_mat) - _mesa_hash_table_insert(split_mats, var, split_mat); - } - nir_foreach_function_temp_variable (var, impl) { - struct split_mat *split_mat = split_var(shader, impl, mem_ctx, var, m_gran, n_gran, k_gran); - if (split_mat) - _mesa_hash_table_insert(split_mats, var, split_mat); - } + bool progress = false; struct split_info split_info = { .split_mats = split_mats, @@ -936,6 +924,14 @@ lower_dimensions(nir_shader *shader, nir_function_impl *impl, .n_gran = n_gran, .k_gran = k_gran, }; + + nir_foreach_variable_in_shader(var, shader) { + split_var(shader, NULL, &split_info, var); + } + nir_foreach_function_temp_variable (var, impl) { + split_var(shader, impl, &split_info, var); + } + progress = split_matrix_impl(impl, &split_info); _mesa_hash_table_destroy(split_mats, NULL); ralloc_free(mem_ctx);