diff --git a/src/intel/compiler/brw_nir_lower_cooperative_matrix.c b/src/intel/compiler/brw_nir_lower_cooperative_matrix.c index bde6b4d0561..25ff68abf25 100644 --- a/src/intel/compiler/brw_nir_lower_cooperative_matrix.c +++ b/src/intel/compiler/brw_nir_lower_cooperative_matrix.c @@ -50,33 +50,65 @@ #include "brw_nir.h" +typedef struct { + /* Vector type that holds the elements packed. */ + const glsl_type *type; + + /* How many cmat elements per slice element. */ + unsigned packing_factor; + + struct glsl_cmat_description desc; + + /* Used by the tables. Variable holding a slice or + * arrays-of-arrays of slices. + * + * If present, the var->type (without arrays!) should match + * the type above. + */ + nir_variable *var; +} slice_info; + struct lower_cmat_state { + void *temp_ctx; + nir_shader *shader; - struct hash_table *slice_coop_types; + struct hash_table *slice_var_to_slice_info; - struct hash_table *vars_to_slice; + struct hash_table *mat_var_to_slice_info; unsigned subgroup_size; }; +static bool +cmat_descriptions_are_equal(struct glsl_cmat_description a, + struct glsl_cmat_description b) +{ + return a.element_type == b.element_type && + a.scope == b.scope && + a.rows == b.rows && + a.cols == b.cols && + a.use == b.use; +} + static void print_coop_types(struct lower_cmat_state *state) { fprintf(stderr, "--- Slices to Cooperative Matrix type table\n"); - hash_table_foreach(state->slice_coop_types, e) { + hash_table_foreach(state->slice_var_to_slice_info, e) { nir_variable *var = (void *)e->key; - const struct glsl_type *t = e->data; - fprintf(stderr, "%p: %s -> %s\n", var, var->name, glsl_get_type_name(t)); + const slice_info *info = e->data; + fprintf(stderr, "%p: %s -> %s\n", var, var->name, + glsl_get_type_name(glsl_cmat_type(&info->desc))); } fprintf(stderr, "\n\n"); } -static const struct glsl_type * -get_coop_type_for_slice(struct lower_cmat_state *state, nir_deref_instr *deref) +static const slice_info * +get_slice_info(struct lower_cmat_state *state, nir_deref_instr *deref) { nir_variable *var = nir_deref_instr_get_variable(deref); - struct hash_entry *entry = _mesa_hash_table_search(state->slice_coop_types, var); + struct hash_entry *entry = _mesa_hash_table_search(state->slice_var_to_slice_info, var); assert(entry != NULL); @@ -116,26 +148,10 @@ lower_cmat_filter(const nir_instr *instr, const void *_state) } } -/** - * Get number of matrix elements packed in each component of the slice. - */ -static unsigned -get_packing_factor(const struct glsl_cmat_description desc, - const struct glsl_type *slice_type) -{ - const struct glsl_type *slice_element_type = glsl_without_array(slice_type); - - assert(!glsl_type_is_cmat(slice_type)); - - assert(glsl_get_bit_size(slice_element_type) >= glsl_base_type_get_bit_size(desc.element_type)); - assert(glsl_get_bit_size(slice_element_type) % glsl_base_type_get_bit_size(desc.element_type) == 0); - - return glsl_get_bit_size(slice_element_type) / glsl_base_type_get_bit_size(desc.element_type); -} - -static const struct glsl_type * -get_slice_type_from_desc(const struct lower_cmat_state *state, - const struct glsl_cmat_description desc) +static void +init_slice_info(struct lower_cmat_state *state, + struct glsl_cmat_description desc, + slice_info *info) { enum glsl_base_type base_type; @@ -195,36 +211,9 @@ get_slice_type_from_desc(const struct lower_cmat_state *state, const struct glsl_type *slice_type = glsl_vector_type(base_type, len); - assert(packing_factor == get_packing_factor(desc, slice_type)); - - return slice_type; -} - -static const struct glsl_type * -get_slice_type(const struct lower_cmat_state *state, - const struct glsl_type *type) -{ - if (glsl_type_is_array(type)) { - const struct glsl_type *slice_type = - get_slice_type(state, glsl_get_array_element(type)); - - return glsl_array_type(slice_type, glsl_array_size(type), 0); - } - - assert(glsl_type_is_cmat(type)); - - return get_slice_type_from_desc(state, - *glsl_get_cmat_description(type)); -} - -static nir_deref_instr * -create_local_slice(struct lower_cmat_state *state, nir_builder *b, - const struct glsl_type *mat_type, const char *name) -{ - const struct glsl_type *slice_type = get_slice_type(state, mat_type); - nir_variable *slice_var = nir_local_variable_create(b->impl, slice_type, name); - _mesa_hash_table_insert(state->slice_coop_types, slice_var, (void *)mat_type); - return nir_build_deref_var(b, slice_var); + info->type = slice_type; + info->desc = desc; + info->packing_factor = packing_factor; } static void @@ -236,12 +225,11 @@ lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intrin, const unsigned ptr_src = load ? 1 : 0; nir_deref_instr *slice = nir_src_as_deref(intrin->src[mat_src]); - const struct glsl_type *mat_type = get_coop_type_for_slice(state, slice); - const struct glsl_cmat_description *desc = glsl_get_cmat_description(mat_type); + const slice_info *info = get_slice_info(state, slice); + const struct glsl_cmat_description desc = info->desc; nir_def *results[NIR_MAX_VEC_COMPONENTS]; const unsigned num_components = glsl_get_vector_elements(slice->type); - const unsigned packing_factor = get_packing_factor(*desc, slice->type); nir_deref_instr *pointer = nir_src_as_deref(intrin->src[ptr_src]); const unsigned ptr_comp_width = glsl_get_bit_size(pointer->type); @@ -255,14 +243,14 @@ lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intrin, nir_imul_imm(b, intrin->src[2].ssa, ptr_comp_width * ptr_num_comps), - glsl_base_type_get_bit_size(desc->element_type)); + glsl_base_type_get_bit_size(desc.element_type)); /* The data that will be packed is in successive columns for A and * accumulator matrices. The data that will be packed for B matrices is in * successive rows. */ const unsigned cols = - desc->use != GLSL_CMAT_USE_B ? desc->cols / packing_factor : desc->cols; + desc.use != GLSL_CMAT_USE_B ? desc.cols / info->packing_factor : desc.cols; nir_def *invocation = nir_load_subgroup_invocation(b); nir_def *invocation_div_cols = nir_udiv_imm(b, invocation, cols); @@ -272,14 +260,14 @@ lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intrin, const bool memory_layout_matches_register_layout = (nir_intrinsic_matrix_layout(intrin) == GLSL_MATRIX_LAYOUT_ROW_MAJOR) == - (desc->use != GLSL_CMAT_USE_B); + (desc.use != GLSL_CMAT_USE_B); if (memory_layout_matches_register_layout) { /* In the row-major arrangement, data is loaded a dword at a time * instead of a single element at a time. For this reason the stride is * divided by the packing factor. */ - i_stride = nir_udiv_imm(b, stride, packing_factor); + i_stride = nir_udiv_imm(b, stride, info->packing_factor); } else { /* In the column-major arrangement, data is loaded a single element at a * time. Because the data elements are transposed, the step direction @@ -290,7 +278,7 @@ lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intrin, * NOTE: The unscaled stride is also still needed when stepping from one * packed element to the next. This occurs in the for-j loop below. */ - i_stride = nir_imul_imm(b, stride, packing_factor); + i_stride = nir_imul_imm(b, stride, info->packing_factor); } nir_def *base_offset; @@ -341,8 +329,8 @@ lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intrin, } } } else { - const struct glsl_type *element_type = glsl_scalar_type(desc->element_type); - const unsigned element_bits = glsl_base_type_get_bit_size(desc->element_type); + const struct glsl_type *element_type = glsl_scalar_type(desc.element_type); + const unsigned element_bits = glsl_base_type_get_bit_size(desc.element_type); const unsigned element_stride = element_bits / 8; pointer = nir_build_deref_cast(b, &pointer->def, pointer->modes, element_type, @@ -352,7 +340,7 @@ lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intrin, nir_def *i_offset = nir_imul_imm(b, i_step, i); nir_def *v[4]; - for (unsigned j = 0; j < packing_factor; j++) { + for (unsigned j = 0; j < info->packing_factor; j++) { nir_def *offset = nir_iadd(b, nir_imul_imm(b, stride, j), i_offset); nir_deref_instr *memory_deref = @@ -376,8 +364,8 @@ lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intrin, } if (load) { - results[i] = nir_pack_bits(b, nir_vec(b, v, packing_factor), - packing_factor * element_bits); + results[i] = nir_pack_bits(b, nir_vec(b, v, info->packing_factor), + info->packing_factor * element_bits); } } } @@ -396,28 +384,15 @@ lower_cmat_unary_op(nir_builder *b, nir_intrinsic_instr *intrin, nir_def *results[NIR_MAX_VEC_COMPONENTS]; const unsigned num_components = glsl_get_vector_elements(dst_slice->type); - const struct glsl_type *dst_mat_type = - get_coop_type_for_slice(state, dst_slice); - const struct glsl_type *src_mat_type = - get_coop_type_for_slice(state, src_slice); + const slice_info *dst_info = get_slice_info(state, dst_slice); + const slice_info *src_info = get_slice_info(state, src_slice); - const struct glsl_cmat_description dst_desc = - *glsl_get_cmat_description(dst_mat_type); - - const struct glsl_cmat_description src_desc = - *glsl_get_cmat_description(src_mat_type); - - const unsigned dst_bits = glsl_base_type_bit_size(dst_desc.element_type); - const unsigned src_bits = glsl_base_type_bit_size(src_desc.element_type); + const unsigned dst_bits = glsl_base_type_bit_size(dst_info->desc.element_type); + const unsigned src_bits = glsl_base_type_bit_size(src_info->desc.element_type); /* The type of the returned slice may be different from the type of the - * input slice. + * input slice if this is a convert operation. */ - const unsigned dst_packing_factor = - get_packing_factor(dst_desc, dst_slice->type); - - const unsigned src_packing_factor = - get_packing_factor(src_desc, src_slice->type); nir_op op; @@ -427,9 +402,9 @@ lower_cmat_unary_op(nir_builder *b, nir_intrinsic_instr *intrin, const nir_cmat_signed cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intrin); enum glsl_base_type src_base_type = glsl_apply_signedness_to_base_type( - src_desc.element_type, cmat_signed_mask & NIR_CMAT_A_SIGNED); + src_info->desc.element_type, cmat_signed_mask & NIR_CMAT_A_SIGNED); enum glsl_base_type dst_base_type = glsl_apply_signedness_to_base_type( - dst_desc.element_type, cmat_signed_mask & NIR_CMAT_RESULT_SIGNED); + dst_info->desc.element_type, cmat_signed_mask & NIR_CMAT_RESULT_SIGNED); op = nir_type_conversion_op(nir_get_nir_type_for_glsl_base_type(src_base_type), nir_get_nir_type_for_glsl_base_type(dst_base_type), @@ -441,16 +416,16 @@ lower_cmat_unary_op(nir_builder *b, nir_intrinsic_instr *intrin, * type conversion possible is int32 <-> float32. As a result * dst_packing_factor == src_packing_factor. */ - assert(dst_packing_factor == src_packing_factor); + assert(dst_info->packing_factor == src_info->packing_factor); /* Stores at most dst_packing_factor partial results. */ nir_def *v[4]; - assert(dst_packing_factor <= 4); + assert(dst_info->packing_factor <= 4); for (unsigned i = 0; i < num_components; i++) { nir_def *chan = nir_channel(b, nir_load_deref(b, src_slice), i); - for (unsigned j = 0; j < dst_packing_factor; j++) { + for (unsigned j = 0; j < dst_info->packing_factor; j++) { nir_def *src = nir_channel(b, nir_unpack_bits(b, chan, src_bits), j); @@ -458,8 +433,8 @@ lower_cmat_unary_op(nir_builder *b, nir_intrinsic_instr *intrin, } results[i] = - nir_pack_bits(b, nir_vec(b, v, dst_packing_factor), - dst_packing_factor * dst_bits); + nir_pack_bits(b, nir_vec(b, v, dst_info->packing_factor), + dst_info->packing_factor * dst_bits); } nir_store_deref(b, dst_slice, nir_vec(b, results, num_components), @@ -479,18 +454,14 @@ lower_cmat_binary_op(nir_builder *b, nir_intrinsic_instr *intrin, nir_def *results[NIR_MAX_VEC_COMPONENTS]; const unsigned num_components = glsl_get_vector_elements(dst_slice->type); - const struct glsl_type *dst_mat_type = get_coop_type_for_slice(state, dst_slice); - ASSERTED const struct glsl_type *src_a_mat_type = get_coop_type_for_slice(state, src_a_slice); - ASSERTED const struct glsl_type *src_b_mat_type = get_coop_type_for_slice(state, src_b_slice); + const slice_info *info = get_slice_info(state, dst_slice); + ASSERTED const slice_info *src_a_info = get_slice_info(state, src_a_slice); + ASSERTED const slice_info *src_b_info = get_slice_info(state, src_b_slice); - const struct glsl_cmat_description desc = - *glsl_get_cmat_description(dst_mat_type); + assert(cmat_descriptions_are_equal(info->desc, src_a_info->desc)); + assert(cmat_descriptions_are_equal(info->desc, src_b_info->desc)); - assert(dst_mat_type == src_a_mat_type); - assert(dst_mat_type == src_b_mat_type); - - const unsigned bits = glsl_base_type_bit_size(desc.element_type); - const unsigned packing_factor = get_packing_factor(desc, dst_slice->type); + const unsigned bits = glsl_base_type_bit_size(info->desc.element_type); for (unsigned i = 0; i < num_components; i++) { nir_def *val_a = nir_channel(b, src_a, i); @@ -500,7 +471,7 @@ lower_cmat_binary_op(nir_builder *b, nir_intrinsic_instr *intrin, nir_pack_bits(b, nir_build_alu2(b, nir_intrinsic_alu_op(intrin), nir_unpack_bits(b, val_a, bits), nir_unpack_bits(b, val_b, bits)), - packing_factor * bits); + info->packing_factor * bits); } nir_store_deref(b, dst_slice, nir_vec(b, results, num_components), @@ -519,15 +490,11 @@ lower_cmat_scalar_op(nir_builder *b, nir_intrinsic_instr *intrin, nir_def *results[NIR_MAX_VEC_COMPONENTS]; const unsigned num_components = glsl_get_vector_elements(dst_slice->type); - ASSERTED const struct glsl_type *dst_mat_type = get_coop_type_for_slice(state, dst_slice); - ASSERTED const struct glsl_type *src_mat_type = get_coop_type_for_slice(state, src_slice); - assert(dst_mat_type == src_mat_type); + const slice_info *info = get_slice_info(state, dst_slice); + ASSERTED const slice_info *src_info = get_slice_info(state, src_slice); + assert(cmat_descriptions_are_equal(info->desc, src_info->desc)); - const struct glsl_cmat_description desc = - *glsl_get_cmat_description(dst_mat_type); - - const unsigned bits = glsl_base_type_bit_size(desc.element_type); - const unsigned packing_factor = get_packing_factor(desc, dst_slice->type); + const unsigned bits = glsl_base_type_bit_size(info->desc.element_type); for (unsigned i = 0; i < num_components; i++) { nir_def *val = nir_channel(b, src, i); @@ -536,7 +503,7 @@ lower_cmat_scalar_op(nir_builder *b, nir_intrinsic_instr *intrin, nir_pack_bits(b, nir_build_alu2(b, nir_intrinsic_alu_op(intrin), nir_unpack_bits(b, val, bits), scalar), - packing_factor * bits); + info->packing_factor * bits); } nir_store_deref(b, dst_slice, nir_vec(b, results, num_components), @@ -557,9 +524,10 @@ lower_cmat_deref(nir_builder *b, nir_deref_instr *deref, assert(deref->var); assert(glsl_type_is_cmat(glsl_without_array(deref->var->type))); - struct hash_entry *entry = _mesa_hash_table_search(state->vars_to_slice, deref->var); + struct hash_entry *entry = _mesa_hash_table_search(state->mat_var_to_slice_info, deref->var); assert(entry); - return nir_build_deref_var(b, (nir_variable *)entry->data); + const slice_info *info = entry->data; + return nir_build_deref_var(b, info->var); } } @@ -584,14 +552,11 @@ lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state) nir_deref_instr *slice = nir_src_as_deref(intrin->src[0]); nir_def *src = intrin->src[1].ssa; - const struct glsl_type *mat_type = get_coop_type_for_slice(state, slice); - const struct glsl_cmat_description desc = - *glsl_get_cmat_description(mat_type); - const unsigned packing_factor = get_packing_factor(desc, slice->type); + const slice_info *info = get_slice_info(state, slice); - if (packing_factor > 1) { - src = nir_pack_bits(b, nir_replicate(b, src, packing_factor), - packing_factor * glsl_base_type_get_bit_size(desc.element_type)); + if (info->packing_factor > 1) { + src = nir_pack_bits(b, nir_replicate(b, src, info->packing_factor), + info->packing_factor * glsl_base_type_get_bit_size(info->desc.element_type)); } const unsigned num_components = glsl_get_vector_elements(slice->type); @@ -615,11 +580,10 @@ lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state) return NIR_LOWER_INSTR_PROGRESS_REPLACE; case nir_intrinsic_cmat_length: { - const struct glsl_cmat_description desc = nir_intrinsic_cmat_desc(intrin); - const struct glsl_type *mat_type = glsl_cmat_type(&desc); - const struct glsl_type *slice_type = get_slice_type(state, mat_type); - return nir_imm_intN_t(b, (get_packing_factor(desc, slice_type) * - glsl_get_vector_elements(slice_type)), 32); + slice_info info = {}; + init_slice_info(state, nir_intrinsic_cmat_desc(intrin), &info); + return nir_imm_intN_t(b, info.packing_factor * + glsl_get_vector_elements(info.type), 32); } case nir_intrinsic_cmat_muladd: { @@ -628,13 +592,9 @@ lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state) nir_deref_instr *B_slice = nir_src_as_deref(intrin->src[2]); nir_deref_instr *accum_slice = nir_src_as_deref(intrin->src[3]); - const struct glsl_type *dst_mat_type = get_coop_type_for_slice(state, dst_slice); - const struct glsl_cmat_description dst_desc = *glsl_get_cmat_description(dst_mat_type); + const slice_info *dst_info = get_slice_info(state, dst_slice); + const slice_info *src_info = get_slice_info(state, A_slice); - const struct glsl_type *src_mat_type = get_coop_type_for_slice(state, A_slice); - const struct glsl_cmat_description src_desc = *glsl_get_cmat_description(src_mat_type); - - const unsigned packing_factor = get_packing_factor(dst_desc, dst_slice->type); const unsigned num_components = glsl_get_vector_elements(dst_slice->type); const nir_cmat_signed cmat_signed_mask = @@ -647,8 +607,8 @@ lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state) assert(((cmat_signed_mask & NIR_CMAT_A_SIGNED) == 0) == ((cmat_signed_mask & NIR_CMAT_RESULT_SIGNED) == 0)); - enum glsl_base_type src_type = src_desc.element_type; - enum glsl_base_type dst_type = dst_desc.element_type; + enum glsl_base_type src_type = src_info->desc.element_type; + enum glsl_base_type dst_type = dst_info->desc.element_type; /* For integer types, the signedness is determined by flags on the * muladd instruction. The types of the sources play no role. Adjust the @@ -666,7 +626,7 @@ lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state) nir_def *result = nir_dpas_intel(b, - packing_factor * glsl_base_type_get_bit_size(dst_desc.element_type), + dst_info->packing_factor * glsl_base_type_get_bit_size(dst_info->desc.element_type), nir_load_deref(b, accum_slice), nir_load_deref(b, A_slice), nir_load_deref(b, B_slice), @@ -707,23 +667,19 @@ lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state) nir_deref_instr *src_slice = nir_src_as_deref(intrin->src[2]); const nir_src dst_index = intrin->src[3]; - const struct glsl_type *dst_mat_type = get_coop_type_for_slice(state, dst_slice); - ASSERTED const struct glsl_type *src_mat_type = get_coop_type_for_slice(state, src_slice); - assert(dst_mat_type == src_mat_type); + const slice_info *info = get_slice_info(state, dst_slice); + ASSERTED const slice_info *src_info = get_slice_info(state, src_slice); + assert(cmat_descriptions_are_equal(info->desc, src_info->desc)); - const struct glsl_cmat_description desc = - *glsl_get_cmat_description(dst_mat_type); - - const unsigned bits = glsl_base_type_bit_size(desc.element_type); - const unsigned packing_factor = get_packing_factor(desc, dst_slice->type); + const unsigned bits = glsl_base_type_bit_size(info->desc.element_type); const unsigned num_components = glsl_get_vector_elements(dst_slice->type); - nir_def *slice_index = nir_udiv_imm(b, dst_index.ssa, packing_factor); - nir_def *vector_index = nir_umod_imm(b, dst_index.ssa, packing_factor); + nir_def *slice_index = nir_udiv_imm(b, dst_index.ssa, info->packing_factor); + nir_def *vector_index = nir_umod_imm(b, dst_index.ssa, info->packing_factor); nir_def *results[NIR_MAX_VEC_COMPONENTS]; const int slice_constant_index = nir_src_is_const(dst_index) - ? nir_src_as_uint(dst_index) / packing_factor + ? nir_src_as_uint(dst_index) / info->packing_factor : -1; for (unsigned i = 0; i < num_components; i++) { @@ -731,13 +687,13 @@ lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state) nir_def *insert; if (slice_constant_index < 0 || slice_constant_index == i) { - if (packing_factor == 1) { + if (info->packing_factor == 1) { insert = scalar; } else { nir_def *unpacked = nir_unpack_bits(b, val, bits); nir_def *v = nir_vector_insert(b, unpacked, scalar, vector_index); - insert = nir_pack_bits(b, v, bits * packing_factor); + insert = nir_pack_bits(b, v, bits * info->packing_factor); } } else { insert = val; @@ -756,25 +712,21 @@ lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state) case nir_intrinsic_cmat_extract: { nir_deref_instr *slice = nir_src_as_deref(intrin->src[0]); - const struct glsl_type *mat_type = get_coop_type_for_slice(state, slice); + const slice_info *info = get_slice_info(state, slice); nir_def *index = intrin->src[1].ssa; - const struct glsl_cmat_description desc = - *glsl_get_cmat_description(mat_type); - - const unsigned bits = glsl_base_type_bit_size(desc.element_type); - const unsigned packing_factor = get_packing_factor(desc, slice->type); + const unsigned bits = glsl_base_type_bit_size(info->desc.element_type); nir_def *src = nir_vector_extract(b, nir_load_deref(b, slice), - nir_udiv_imm(b, index, packing_factor)); + nir_udiv_imm(b, index, info->packing_factor)); - if (packing_factor == 1) { + if (info->packing_factor == 1) { return src; } else { return nir_vector_extract(b, nir_unpack_bits(b, src, bits), - nir_umod_imm(b, index, packing_factor)); + nir_umod_imm(b, index, info->packing_factor)); } return NIR_LOWER_INSTR_PROGRESS_REPLACE; @@ -785,25 +737,40 @@ lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state) } } +static const glsl_type * +make_aoa_slice_type(const glsl_type *t, const glsl_type *slice_type) +{ + if (glsl_type_is_array(t)) { + const glsl_type *s = make_aoa_slice_type(glsl_get_array_element(t), slice_type); + return glsl_array_type(s, glsl_array_size(t), 0); + } + + assert(glsl_type_is_cmat(t)); + return slice_type; +} + static void create_slice_var(struct lower_cmat_state *state, nir_variable *var, nir_function_impl *impl) { - // TODO: without array const struct glsl_type *mat_type = glsl_without_array(var->type); assert(glsl_type_is_cmat(mat_type)); assert((!impl && var->data.mode == nir_var_shader_temp) || ( impl && var->data.mode == nir_var_function_temp)); - const struct glsl_type *slice_type = get_slice_type(state, var->type); - const char *slice_name = ralloc_asprintf(state->shader, "%s_slice", var->name); - nir_variable *slice_var = impl ? - nir_local_variable_create(impl, slice_type, slice_name) : - nir_variable_create(state->shader, var->data.mode, slice_type, slice_name); + slice_info *info = rzalloc(state->temp_ctx, slice_info); + init_slice_info(state, *glsl_get_cmat_description(mat_type), info); - _mesa_hash_table_insert(state->vars_to_slice, var, slice_var); - _mesa_hash_table_insert(state->slice_coop_types, slice_var, (void *)mat_type); + const glsl_type *aoa_slice_type = make_aoa_slice_type(var->type, info->type); + + const char *slice_name = ralloc_asprintf(state->shader, "%s_slice", var->name); + info->var = impl ? + nir_local_variable_create(impl, aoa_slice_type, slice_name) : + nir_variable_create(state->shader, var->data.mode, aoa_slice_type, slice_name); + + _mesa_hash_table_insert(state->mat_var_to_slice_info, var, info); + _mesa_hash_table_insert(state->slice_var_to_slice_info, info->var, info); } bool @@ -812,9 +779,10 @@ brw_nir_lower_cmat(nir_shader *shader, unsigned subgroup_size) void *temp_ctx = ralloc_context(NULL); struct lower_cmat_state state = { + .temp_ctx = temp_ctx, .shader = shader, - .slice_coop_types = _mesa_pointer_hash_table_create(temp_ctx), - .vars_to_slice = _mesa_pointer_hash_table_create(temp_ctx), + .slice_var_to_slice_info = _mesa_pointer_hash_table_create(temp_ctx), + .mat_var_to_slice_info = _mesa_pointer_hash_table_create(temp_ctx), .subgroup_size = subgroup_size, };