/*
 * Copyright 2023 Intel Corporation
 * SPDX-License-Identifier: MIT
 */

/**
 * \file brw_nir_lower_cooperative_matrix.c
 * Lower cooperative matrix to subgroup operations.
 */

#include "brw_nir.h"

struct lower_cmat_state {
   nir_shader *shader;
   struct hash_table *slice_coop_types;
   struct hash_table *vars_to_slice;
   unsigned subgroup_size;
};

/* Debug helper, intended to be called from a debugger.  Marked UNUSED since
 * nothing in this file calls it by default.
 */
UNUSED static void
print_coop_types(struct lower_cmat_state *state)
{
   fprintf(stderr, "--- Slices to Cooperative Matrix type table\n");
   hash_table_foreach(state->slice_coop_types, e) {
      nir_variable *var = (void *)e->key;
      const struct glsl_type *t = e->data;
      fprintf(stderr, "%p: %s -> %s\n", var, var->name, glsl_get_type_name(t));
   }
   fprintf(stderr, "\n\n");
}

static const struct glsl_type *
get_coop_type_for_slice(struct lower_cmat_state *state, nir_deref_instr *deref)
{
   nir_variable *var = nir_deref_instr_get_variable(deref);
   struct hash_entry *entry =
      _mesa_hash_table_search(state->slice_coop_types, var);

   assert(entry != NULL);
   return entry->data;
}

static bool
lower_cmat_filter(const nir_instr *instr, const void *_state)
{
   if (instr->type == nir_instr_type_deref) {
      nir_deref_instr *deref = nir_instr_as_deref(instr);
      return glsl_type_is_cmat(deref->type);
   }

   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   switch (intrin->intrinsic) {
   case nir_intrinsic_cmat_construct:
   case nir_intrinsic_cmat_load:
   case nir_intrinsic_cmat_store:
   case nir_intrinsic_cmat_length:
   case nir_intrinsic_cmat_muladd:
   case nir_intrinsic_cmat_unary_op:
   case nir_intrinsic_cmat_binary_op:
   case nir_intrinsic_cmat_scalar_op:
   case nir_intrinsic_cmat_bitcast:
   case nir_intrinsic_cmat_insert:
   case nir_intrinsic_cmat_extract:
   case nir_intrinsic_cmat_copy:
      return true;

   default:
      return false;
   }
}

/* Each cooperative matrix is lowered to a "slice": a plain vector holding
 * this invocation's rows*cols/subgroup_size elements of the matrix.  Arrays
 * of matrices become arrays of slices.
 */
static const struct glsl_type *
get_slice_type(const struct lower_cmat_state *state,
               const struct glsl_type *type)
{
   if (glsl_type_is_array(type)) {
      const struct glsl_type *slice_type =
         get_slice_type(state, glsl_get_array_element(type));
      return glsl_array_type(slice_type, glsl_array_size(type), 0);
   }

   assert(glsl_type_is_cmat(type));

   const struct glsl_cmat_description *desc = glsl_get_cmat_description(type);
   unsigned int len = (desc->rows * desc->cols) / state->subgroup_size;
   assert(len > 0);

   return glsl_vector_type(desc->element_type, len);
}

/* Currently has no callers in this file (the cmat_muladd lowering is still
 * FINISHME), so it is marked UNUSED to keep -Wunused-function quiet.
 */
UNUSED static nir_deref_instr *
create_local_slice(struct lower_cmat_state *state, nir_builder *b,
                   const struct glsl_type *mat_type, const char *name)
{
   const struct glsl_type *slice_type = get_slice_type(state, mat_type);
   nir_variable *slice_var =
      nir_local_variable_create(b->impl, slice_type, name);
   _mesa_hash_table_insert(state->slice_coop_types, slice_var,
                           (void *)mat_type);
   return nir_build_deref_var(b, slice_var);
}
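/* Lower cmat_load/cmat_store to per-invocation deref loads/stores.
 *
 * Given the loop below, component i of invocation j's slice maps to memory
 * element j * stride + i.  As an illustrative example: for an 8x8 matrix
 * with subgroup size 8, each invocation holds a vec8 slice, and with a
 * row-major layout whose stride equals the row pitch, invocation j simply
 * reads or writes row j of the matrix.
 */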
static void
lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intrin,
                      struct lower_cmat_state *state)
{
   const bool load = intrin->intrinsic == nir_intrinsic_cmat_load;
   const unsigned mat_src = load ? 0 : 1;
   const unsigned ptr_src = load ? 1 : 0;

   /* TODO: Column major. */
   assert(nir_intrinsic_matrix_layout(intrin) == GLSL_MATRIX_LAYOUT_ROW_MAJOR);

   nir_deref_instr *slice = nir_src_as_deref(intrin->src[mat_src]);
   const struct glsl_type *mat_type = get_coop_type_for_slice(state, slice);
   const struct glsl_cmat_description *desc = glsl_get_cmat_description(mat_type);

   /* TODO: Dynamic stride. */
   assert(nir_src_is_const(intrin->src[2]));

   nir_def *results[NIR_MAX_VEC_COMPONENTS];
   const unsigned num_components = glsl_get_vector_elements(slice->type);
   nir_deref_instr *pointer = nir_src_as_deref(intrin->src[ptr_src]);
   const unsigned stride = nir_src_as_uint(intrin->src[2]);
   const struct glsl_type *element_type = glsl_get_array_element(slice->type);

   /* Recast the pointer as an array of elements large enough to cover the
    * whole matrix at the given stride.
    */
   const struct glsl_type *pointer_type =
      glsl_array_type(element_type, MAX2(desc->rows, desc->cols) * stride,
                      glsl_get_bit_size(element_type) / 8);
   pointer = nir_build_deref_cast(b, &pointer->def, pointer->modes,
                                  pointer_type,
                                  glsl_get_bit_size(element_type) / 8);

   for (unsigned i = 0; i < num_components; i++) {
      nir_def *offset = nir_imul_imm(b, nir_load_subgroup_invocation(b), stride);
      nir_deref_instr *memory_deref =
         nir_build_deref_array(b, pointer,
                               nir_i2iN(b, nir_iadd_imm(b, offset, i),
                                        pointer->def.bit_size));

      if (load) {
         results[i] = nir_load_deref(b, memory_deref);
      } else {
         nir_def *src = nir_channel(b, nir_load_deref(b, slice), i);
         nir_store_deref(b, memory_deref, src, 0x1);
      }
   }

   if (load)
      nir_store_deref(b, slice, nir_vec(b, results, num_components), 0xffff);
}

static void
lower_cmat_unary_op(nir_builder *b, nir_intrinsic_instr *intrin,
                    struct lower_cmat_state *state)
{
   nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);
   nir_def *src = nir_load_deref(b, nir_src_as_deref(intrin->src[1]));
   nir_def *results[NIR_MAX_VEC_COMPONENTS];
   const unsigned num_components = glsl_get_vector_elements(dst_slice->type);

   for (unsigned i = 0; i < num_components; i++) {
      nir_def *val = nir_channel(b, src, i);
      results[i] = nir_build_alu1(b, nir_intrinsic_alu_op(intrin), val);
   }

   nir_store_deref(b, dst_slice, nir_vec(b, results, num_components),
                   nir_component_mask(num_components));
}

static void
lower_cmat_binary_op(nir_builder *b, nir_intrinsic_instr *intrin,
                     struct lower_cmat_state *state)
{
   nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);
   nir_deref_instr *src_a_slice = nir_src_as_deref(intrin->src[1]);
   nir_deref_instr *src_b_slice = nir_src_as_deref(intrin->src[2]);
   nir_def *src_a = nir_load_deref(b, src_a_slice);
   nir_def *src_b = nir_load_deref(b, src_b_slice);
   nir_def *results[NIR_MAX_VEC_COMPONENTS];
   const unsigned num_components = glsl_get_vector_elements(dst_slice->type);

   ASSERTED const struct glsl_type *dst_mat_type =
      get_coop_type_for_slice(state, dst_slice);
   ASSERTED const struct glsl_type *src_a_mat_type =
      get_coop_type_for_slice(state, src_a_slice);
   ASSERTED const struct glsl_type *src_b_mat_type =
      get_coop_type_for_slice(state, src_b_slice);
   assert(dst_mat_type == src_a_mat_type);
   assert(dst_mat_type == src_b_mat_type);

   for (unsigned i = 0; i < num_components; i++) {
      nir_def *val_a = nir_channel(b, src_a, i);
      nir_def *val_b = nir_channel(b, src_b, i);
      results[i] = nir_build_alu2(b, nir_intrinsic_alu_op(intrin), val_a, val_b);
   }

   nir_store_deref(b, dst_slice, nir_vec(b, results, num_components),
                   nir_component_mask(num_components));
}

static void
lower_cmat_scalar_op(nir_builder *b, nir_intrinsic_instr *intrin,
                     struct lower_cmat_state *state)
{
   nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);
   nir_deref_instr *src_slice = nir_src_as_deref(intrin->src[1]);
   nir_def *scalar = intrin->src[2].ssa;
   nir_def *src = nir_load_deref(b, src_slice);
   nir_def *results[NIR_MAX_VEC_COMPONENTS];
   const unsigned num_components = glsl_get_vector_elements(dst_slice->type);

   ASSERTED const struct glsl_type *dst_mat_type =
      get_coop_type_for_slice(state, dst_slice);
   ASSERTED const struct glsl_type *src_mat_type =
      get_coop_type_for_slice(state, src_slice);
   assert(dst_mat_type == src_mat_type);

   for (unsigned i = 0; i < num_components; i++) {
      nir_def *val = nir_channel(b, src, i);
      results[i] = nir_build_alu2(b, nir_intrinsic_alu_op(intrin), val, scalar);
   }

   nir_store_deref(b, dst_slice, nir_vec(b, results, num_components),
                   nir_component_mask(num_components));
}
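/* Rewrite a deref chain rooted at a cooperative matrix variable so it points
 * at the matching slice variable instead.  Only variable and array derefs
 * can appear here; matrices inside structs are not handled yet (see the TODO
 * in brw_nir_lower_cmat below).
 */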
static nir_deref_instr *
lower_cmat_deref(nir_builder *b, nir_deref_instr *deref,
                 struct lower_cmat_state *state)
{
   nir_deref_instr *parent = nir_deref_instr_parent(deref);
   if (parent) {
      assert(deref->deref_type == nir_deref_type_array);
      parent = lower_cmat_deref(b, parent, state);
      return nir_build_deref_array(b, parent, deref->arr.index.ssa);
   } else {
      assert(deref->deref_type == nir_deref_type_var);
      assert(deref->var);
      assert(glsl_type_is_cmat(glsl_without_array(deref->var->type)));

      struct hash_entry *entry =
         _mesa_hash_table_search(state->vars_to_slice, deref->var);
      assert(entry);
      return nir_build_deref_var(b, (nir_variable *)entry->data);
   }
}

static nir_def *
lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state)
{
   struct lower_cmat_state *state = _state;

   if (instr->type == nir_instr_type_deref) {
      nir_deref_instr *deref =
         lower_cmat_deref(b, nir_instr_as_deref(instr), state);
      return &deref->def;
   }

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   switch (intrin->intrinsic) {
   case nir_intrinsic_cmat_load:
   case nir_intrinsic_cmat_store:
      lower_cmat_load_store(b, intrin, state);
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;

   case nir_intrinsic_cmat_construct: {
      nir_deref_instr *slice = nir_src_as_deref(intrin->src[0]);
      nir_def *src = intrin->src[1].ssa;
      const unsigned num_components = glsl_get_vector_elements(slice->type);

      nir_store_deref(b, slice, nir_replicate(b, src, num_components),
                      nir_component_mask(num_components));
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;
   }

   case nir_intrinsic_cmat_unary_op:
      lower_cmat_unary_op(b, intrin, state);
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;

   case nir_intrinsic_cmat_binary_op:
      lower_cmat_binary_op(b, intrin, state);
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;

   case nir_intrinsic_cmat_scalar_op:
      lower_cmat_scalar_op(b, intrin, state);
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;

   case nir_intrinsic_cmat_length: {
      const struct glsl_cmat_description desc = nir_intrinsic_cmat_desc(intrin);
      const struct glsl_type *mat_type = glsl_cmat_type(&desc);
      const struct glsl_type *slice_type = get_slice_type(state, mat_type);
      return nir_imm_intN_t(b, glsl_get_vector_elements(slice_type), 32);
   }

   case nir_intrinsic_cmat_muladd:
      /* FINISHME. */
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;

   case nir_intrinsic_cmat_bitcast:
   case nir_intrinsic_cmat_insert:
   case nir_intrinsic_cmat_extract:
      /* FINISHME. */
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;

   case nir_intrinsic_cmat_copy:
      nir_copy_deref(b,
                     nir_src_as_deref(intrin->src[0]),
                     nir_src_as_deref(intrin->src[1]));
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;

   default:
      unreachable("invalid cooperative matrix intrinsic");
   }
}
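/* Create the slice variable that will replace a cooperative matrix variable,
 * mirroring its storage mode (shader_temp when impl is NULL, function_temp
 * otherwise), and record both mappings consulted during lowering.
 */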
static void
create_slice_var(struct lower_cmat_state *state, nir_variable *var,
                 nir_function_impl *impl)
{
   /* TODO: without array. */
   const struct glsl_type *mat_type = glsl_without_array(var->type);
   assert(glsl_type_is_cmat(mat_type));
   assert((!impl && var->data.mode == nir_var_shader_temp) ||
          ( impl && var->data.mode == nir_var_function_temp));

   const struct glsl_type *slice_type = get_slice_type(state, var->type);
   const char *slice_name =
      ralloc_asprintf(state->shader, "%s_slice", var->name);
   nir_variable *slice_var = impl ?
      nir_local_variable_create(impl, slice_type, slice_name) :
      nir_variable_create(state->shader, var->data.mode, slice_type, slice_name);

   _mesa_hash_table_insert(state->vars_to_slice, var, slice_var);
   _mesa_hash_table_insert(state->slice_coop_types, slice_var, (void *)mat_type);
}

bool
brw_nir_lower_cmat(nir_shader *shader, unsigned subgroup_size)
{
   void *temp_ctx = ralloc_context(NULL);

   struct lower_cmat_state state = {
      .shader = shader,
      .slice_coop_types = _mesa_pointer_hash_table_create(temp_ctx),
      .vars_to_slice = _mesa_pointer_hash_table_create(temp_ctx),
      .subgroup_size = subgroup_size,
   };

   /* Create a slice array for each variable and add a map from the original
    * variable back to it, so it can be reached during lowering.
    *
    * TODO: Cooperative matrix inside struct?
    */
   nir_foreach_variable_in_shader(var, shader) {
      if (glsl_type_is_cmat(glsl_without_array(var->type)))
         create_slice_var(&state, var, NULL);
   }
   nir_foreach_function(func, shader) {
      nir_foreach_function_temp_variable(var, func->impl) {
         if (glsl_type_is_cmat(glsl_without_array(var->type)))
            create_slice_var(&state, var, func->impl);
      }
   }

   bool progress = nir_shader_lower_instructions(shader,
                                                 lower_cmat_filter,
                                                 lower_cmat_instr,
                                                 &state);

   ralloc_free(temp_ctx);

   return progress;
}
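/* Typical usage (illustrative sketch, not from this file): a driver would
 * run this pass once the subgroup size is known, e.g.
 *
 *    bool progress = false;
 *    NIR_PASS(progress, nir, brw_nir_lower_cmat, 16);
 *
 * followed by the usual cleanup passes so the now-dead cooperative matrix
 * variables and copies are eliminated.
 */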