diff --git a/src/intel/compiler/brw_nir.h b/src/intel/compiler/brw_nir.h
index c93c64223c8..119de1c6086 100644
--- a/src/intel/compiler/brw_nir.h
+++ b/src/intel/compiler/brw_nir.h
@@ -189,6 +189,8 @@ void brw_nir_lower_fs_outputs(nir_shader *nir);
 
 bool brw_nir_lower_conversions(nir_shader *nir);
 
+bool brw_nir_lower_cmat(nir_shader *nir, unsigned subgroup_size);
+
 bool brw_nir_lower_shading_rate_output(nir_shader *nir);
 
 bool brw_nir_lower_sparse_intrinsics(nir_shader *nir);
diff --git a/src/intel/compiler/brw_nir_lower_cooperative_matrix.c b/src/intel/compiler/brw_nir_lower_cooperative_matrix.c
new file mode 100644
index 00000000000..69fabe79fcf
--- /dev/null
+++ b/src/intel/compiler/brw_nir_lower_cooperative_matrix.c
@@ -0,0 +1,381 @@
+/*
+ * Copyright 2023 Intel Corporation
+ * SPDX-License-Identifier: MIT
+ */
+
+/**
+ * \file brw_nir_lower_cooperative_matrix.c
+ * Lower cooperative matrix types and intrinsics to subgroup operations on
+ * plain vector "slices".
+ */
+
+#include "brw_nir.h"
+
+struct lower_cmat_state {
+   nir_shader *shader;
+
+   /* Maps a slice variable to the cooperative matrix type it implements. */
+   struct hash_table *slice_coop_types;
+
+   /* Maps an original cooperative matrix variable to its slice variable. */
+   struct hash_table *vars_to_slice;
+
+   unsigned subgroup_size;
+};
+
+/* Debug helper: dump the slice-to-matrix-type table. */
+UNUSED static void
+print_coop_types(struct lower_cmat_state *state)
+{
+   fprintf(stderr, "--- Slices to Cooperative Matrix type table\n");
+   hash_table_foreach(state->slice_coop_types, e) {
+      nir_variable *var = (void *)e->key;
+      const struct glsl_type *t = e->data;
+      fprintf(stderr, "%p: %s -> %s\n", var, var->name, glsl_get_type_name(t));
+   }
+   fprintf(stderr, "\n\n");
+}
+
+static const struct glsl_type *
+get_coop_type_for_slice(struct lower_cmat_state *state, nir_deref_instr *deref)
+{
+   nir_variable *var = nir_deref_instr_get_variable(deref);
+   struct hash_entry *entry = _mesa_hash_table_search(state->slice_coop_types, var);
+
+   assert(entry != NULL);
+
+   return entry->data;
+}
+
+static bool
+lower_cmat_filter(const nir_instr *instr, const void *_state)
+{
+   if (instr->type == nir_instr_type_deref) {
+      nir_deref_instr *deref = nir_instr_as_deref(instr);
+      return glsl_type_is_cmat(deref->type);
+   }
+
+   if (instr->type != nir_instr_type_intrinsic)
+      return false;
+
+   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
+   switch (intrin->intrinsic) {
+   case nir_intrinsic_cmat_construct:
+   case nir_intrinsic_cmat_load:
+   case nir_intrinsic_cmat_store:
+   case nir_intrinsic_cmat_length:
+   case nir_intrinsic_cmat_muladd:
+   case nir_intrinsic_cmat_unary_op:
+   case nir_intrinsic_cmat_binary_op:
+   case nir_intrinsic_cmat_scalar_op:
+   case nir_intrinsic_cmat_bitcast:
+   case nir_intrinsic_cmat_insert:
+   case nir_intrinsic_cmat_extract:
+   case nir_intrinsic_cmat_copy:
+      return true;
+
+   default:
+      return false;
+   }
+}
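+
+/* Illustrative example of the sizing rule below: a 16x16 matrix of
+ * float16_t lowered at subgroup_size 16 yields a slice of
+ * (16 * 16) / 16 = 16 components, i.e. one float16_t vec16 per
+ * invocation; an array of matrices becomes an array of such vectors.
+ */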
+static const struct glsl_type *
+get_slice_type(const struct lower_cmat_state *state,
+               const struct glsl_type *type)
+{
+   if (glsl_type_is_array(type)) {
+      const struct glsl_type *slice_type =
+         get_slice_type(state, glsl_get_array_element(type));
+
+      return glsl_array_type(slice_type, glsl_array_size(type), 0);
+   }
+
+   assert(glsl_type_is_cmat(type));
+   const struct glsl_cmat_description *desc = glsl_get_cmat_description(type);
+   unsigned len = (desc->rows * desc->cols) / state->subgroup_size;
+   assert(len > 0);
+   return glsl_vector_type(desc->element_type, len);
+}
+
+static nir_deref_instr *
+create_local_slice(struct lower_cmat_state *state, nir_builder *b,
+                   const struct glsl_type *mat_type, const char *name)
+{
+   const struct glsl_type *slice_type = get_slice_type(state, mat_type);
+   nir_variable *slice_var = nir_local_variable_create(b->impl, slice_type, name);
+
+   _mesa_hash_table_insert(state->slice_coop_types, slice_var, (void *)mat_type);
+
+   return nir_build_deref_var(b, slice_var);
+}
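+
+/* Illustrative note on the addressing below: invocation N accesses the
+ * linear element indices N * stride + i, for i in [0, num_components), so
+ * stride is measured in elements rather than bytes.  With the row-major
+ * layout asserted here, each invocation covers one contiguous run of the
+ * source or destination array.
+ */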
+static void
+lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intrin,
+                      struct lower_cmat_state *state)
+{
+   const bool load = intrin->intrinsic == nir_intrinsic_cmat_load;
+   const unsigned mat_src = load ? 0 : 1;
+   const unsigned ptr_src = load ? 1 : 0;
+
+   /* TODO: Column major. */
+   assert(nir_intrinsic_matrix_layout(intrin) == GLSL_MATRIX_LAYOUT_ROW_MAJOR);
+
+   nir_deref_instr *slice = nir_src_as_deref(intrin->src[mat_src]);
+   const struct glsl_type *mat_type = get_coop_type_for_slice(state, slice);
+   const struct glsl_cmat_description *desc = glsl_get_cmat_description(mat_type);
+
+   /* TODO: Dynamic stride. */
+   assert(nir_src_is_const(intrin->src[2]));
+
+   nir_def *results[NIR_MAX_VEC_COMPONENTS];
+   const unsigned num_components = glsl_get_vector_elements(slice->type);
+
+   nir_deref_instr *pointer = nir_src_as_deref(intrin->src[ptr_src]);
+   const unsigned stride = nir_src_as_uint(intrin->src[2]);
+
+   /* Cast the pointer to an array of matrix elements large enough to cover
+    * the whole matrix at the given stride.
+    */
+   const struct glsl_type *element_type =
+      glsl_get_array_element(slice->type);
+
+   const struct glsl_type *pointer_type =
+      glsl_array_type(element_type, MAX2(desc->rows, desc->cols) * stride,
+                      glsl_get_bit_size(element_type) / 8);
+
+   pointer = nir_build_deref_cast(b, &pointer->def, pointer->modes, pointer_type,
+                                  glsl_get_bit_size(element_type) / 8);
+
+   nir_def *base = nir_imul_imm(b, nir_load_subgroup_invocation(b), stride);
+
+   for (unsigned i = 0; i < num_components; i++) {
+      nir_deref_instr *memory_deref =
+         nir_build_deref_array(b, pointer,
+                               nir_i2iN(b, nir_iadd_imm(b, base, i),
+                                        pointer->def.bit_size));
+
+      if (load) {
+         results[i] = nir_load_deref(b, memory_deref);
+      } else {
+         nir_def *src = nir_channel(b, nir_load_deref(b, slice), i);
+         nir_store_deref(b, memory_deref, src, 0x1);
+      }
+   }
+
+   if (load)
+      nir_store_deref(b, slice, nir_vec(b, results, num_components),
+                      nir_component_mask(num_components));
+}
+
+static void
+lower_cmat_unary_op(nir_builder *b, nir_intrinsic_instr *intrin,
+                    struct lower_cmat_state *state)
+{
+   nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);
+   nir_def *src = nir_load_deref(b, nir_src_as_deref(intrin->src[1]));
+   nir_def *results[NIR_MAX_VEC_COMPONENTS];
+   const unsigned num_components = glsl_get_vector_elements(dst_slice->type);
+
+   for (unsigned i = 0; i < num_components; i++) {
+      nir_def *val = nir_channel(b, src, i);
+      results[i] = nir_build_alu1(b, nir_intrinsic_alu_op(intrin), val);
+   }
+
+   nir_store_deref(b, dst_slice, nir_vec(b, results, num_components),
+                   nir_component_mask(num_components));
+}
+
+static void
+lower_cmat_binary_op(nir_builder *b, nir_intrinsic_instr *intrin,
+                     struct lower_cmat_state *state)
+{
+   nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);
+   nir_deref_instr *src_a_slice = nir_src_as_deref(intrin->src[1]);
+   nir_deref_instr *src_b_slice = nir_src_as_deref(intrin->src[2]);
+
+   nir_def *src_a = nir_load_deref(b, src_a_slice);
+   nir_def *src_b = nir_load_deref(b, src_b_slice);
+   nir_def *results[NIR_MAX_VEC_COMPONENTS];
+   const unsigned num_components = glsl_get_vector_elements(dst_slice->type);
+
+   ASSERTED const struct glsl_type *dst_mat_type = get_coop_type_for_slice(state, dst_slice);
+   ASSERTED const struct glsl_type *src_a_mat_type = get_coop_type_for_slice(state, src_a_slice);
+   ASSERTED const struct glsl_type *src_b_mat_type = get_coop_type_for_slice(state, src_b_slice);
+
+   assert(dst_mat_type == src_a_mat_type);
+   assert(dst_mat_type == src_b_mat_type);
+
+   for (unsigned i = 0; i < num_components; i++) {
+      nir_def *val_a = nir_channel(b, src_a, i);
+      nir_def *val_b = nir_channel(b, src_b, i);
+
+      results[i] = nir_build_alu2(b, nir_intrinsic_alu_op(intrin), val_a, val_b);
+   }
+
+   nir_store_deref(b, dst_slice, nir_vec(b, results, num_components),
+                   nir_component_mask(num_components));
+}
+
+static void
+lower_cmat_scalar_op(nir_builder *b, nir_intrinsic_instr *intrin,
+                     struct lower_cmat_state *state)
+{
+   nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);
+   nir_deref_instr *src_slice = nir_src_as_deref(intrin->src[1]);
+   nir_def *scalar = intrin->src[2].ssa;
+
+   nir_def *src = nir_load_deref(b, src_slice);
+   nir_def *results[NIR_MAX_VEC_COMPONENTS];
+   const unsigned num_components = glsl_get_vector_elements(dst_slice->type);
+
+   ASSERTED const struct glsl_type *dst_mat_type = get_coop_type_for_slice(state, dst_slice);
+   ASSERTED const struct glsl_type *src_mat_type = get_coop_type_for_slice(state, src_slice);
+   assert(dst_mat_type == src_mat_type);
+
+   for (unsigned i = 0; i < num_components; i++) {
+      nir_def *val = nir_channel(b, src, i);
+
+      results[i] = nir_build_alu2(b, nir_intrinsic_alu_op(intrin), val, scalar);
+   }
+
+   nir_store_deref(b, dst_slice, nir_vec(b, results, num_components),
+                   nir_component_mask(num_components));
+}
+
+static nir_deref_instr *
+lower_cmat_deref(nir_builder *b, nir_deref_instr *deref,
+                 struct lower_cmat_state *state)
+{
+   nir_deref_instr *parent = nir_deref_instr_parent(deref);
+   if (parent) {
+      assert(deref->deref_type == nir_deref_type_array);
+      parent = lower_cmat_deref(b, parent, state);
+      return nir_build_deref_array(b, parent, deref->arr.index.ssa);
+   } else {
+      assert(deref->deref_type == nir_deref_type_var);
+      assert(deref->var);
+      assert(glsl_type_is_cmat(glsl_without_array(deref->var->type)));
+
+      struct hash_entry *entry = _mesa_hash_table_search(state->vars_to_slice, deref->var);
+      assert(entry);
+      return nir_build_deref_var(b, (nir_variable *)entry->data);
+   }
+}
+
+static nir_def *
+lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state)
+{
+   struct lower_cmat_state *state = _state;
+
+   if (instr->type == nir_instr_type_deref) {
+      nir_deref_instr *deref = lower_cmat_deref(b, nir_instr_as_deref(instr), state);
+      return &deref->def;
+   }
+
+   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
+   switch (intrin->intrinsic) {
+   case nir_intrinsic_cmat_load:
+   case nir_intrinsic_cmat_store:
+      lower_cmat_load_store(b, intrin, state);
+      return NIR_LOWER_INSTR_PROGRESS_REPLACE;
+
+   case nir_intrinsic_cmat_construct: {
+      nir_deref_instr *slice = nir_src_as_deref(intrin->src[0]);
+      nir_def *src = intrin->src[1].ssa;
+      const unsigned num_components = glsl_get_vector_elements(slice->type);
+
+      nir_store_deref(b, slice, nir_replicate(b, src, num_components),
+                      nir_component_mask(num_components));
+      return NIR_LOWER_INSTR_PROGRESS_REPLACE;
+   }
+
+   case nir_intrinsic_cmat_unary_op:
+      lower_cmat_unary_op(b, intrin, state);
+      return NIR_LOWER_INSTR_PROGRESS_REPLACE;
+
+   case nir_intrinsic_cmat_binary_op:
+      lower_cmat_binary_op(b, intrin, state);
+      return NIR_LOWER_INSTR_PROGRESS_REPLACE;
+
+   case nir_intrinsic_cmat_scalar_op:
+      lower_cmat_scalar_op(b, intrin, state);
+      return NIR_LOWER_INSTR_PROGRESS_REPLACE;
+
+   case nir_intrinsic_cmat_length: {
+      /* The result is the per-invocation slice length, not rows * cols. */
+      const struct glsl_cmat_description desc = nir_intrinsic_cmat_desc(intrin);
+      const struct glsl_type *mat_type = glsl_cmat_type(&desc);
+      const struct glsl_type *slice_type = get_slice_type(state, mat_type);
+      return nir_imm_intN_t(b, glsl_get_vector_elements(slice_type), 32);
+   }
+
+   case nir_intrinsic_cmat_muladd:
+      /* FINISHME. */
+      return NIR_LOWER_INSTR_PROGRESS_REPLACE;
+
+   case nir_intrinsic_cmat_bitcast:
+   case nir_intrinsic_cmat_insert:
+   case nir_intrinsic_cmat_extract:
+      /* FINISHME. */
+      return NIR_LOWER_INSTR_PROGRESS_REPLACE;
+
+   case nir_intrinsic_cmat_copy:
+      nir_copy_deref(b,
+                     nir_src_as_deref(intrin->src[0]),
+                     nir_src_as_deref(intrin->src[1]));
+      return NIR_LOWER_INSTR_PROGRESS_REPLACE;
+
+   default:
+      unreachable("invalid cooperative matrix intrinsic");
+   }
+}
+
+static void
+create_slice_var(struct lower_cmat_state *state, nir_variable *var,
+                 nir_function_impl *impl)
+{
+   /* TODO: without array. */
+   const struct glsl_type *mat_type = glsl_without_array(var->type);
+
+   assert(glsl_type_is_cmat(mat_type));
+   assert((!impl && var->data.mode == nir_var_shader_temp) ||
+          ( impl && var->data.mode == nir_var_function_temp));
+
+   const struct glsl_type *slice_type = get_slice_type(state, var->type);
+   const char *slice_name = ralloc_asprintf(state->shader, "%s_slice", var->name);
+   nir_variable *slice_var = impl ?
+      nir_local_variable_create(impl, slice_type, slice_name) :
+      nir_variable_create(state->shader, var->data.mode, slice_type, slice_name);
+
+   _mesa_hash_table_insert(state->vars_to_slice, var, slice_var);
+   _mesa_hash_table_insert(state->slice_coop_types, slice_var, (void *)mat_type);
+}
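+
+/* Pass entry point.  Replaces each cooperative matrix variable with a plain
+ * vector ("slice") variable sized for the given subgroup size, then rewrites
+ * the cmat_* derefs and intrinsics in terms of those slices.
+ */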
+bool
+brw_nir_lower_cmat(nir_shader *shader, unsigned subgroup_size)
+{
+   void *temp_ctx = ralloc_context(NULL);
+
+   struct lower_cmat_state state = {
+      .shader = shader,
+      .slice_coop_types = _mesa_pointer_hash_table_create(temp_ctx),
+      .vars_to_slice = _mesa_pointer_hash_table_create(temp_ctx),
+      .subgroup_size = subgroup_size,
+   };
+
+   /* Create a slice variable for each cooperative matrix variable and add a
+    * mapping from the original variable back to it, so it can be found
+    * during lowering.
+    *
+    * TODO: Cooperative matrix inside struct?
+    */
+   nir_foreach_variable_in_shader(var, shader) {
+      if (glsl_type_is_cmat(glsl_without_array(var->type)))
+         create_slice_var(&state, var, NULL);
+   }
+   nir_foreach_function(func, shader) {
+      if (func->impl == NULL)
+         continue;
+
+      nir_foreach_function_temp_variable(var, func->impl) {
+         if (glsl_type_is_cmat(glsl_without_array(var->type)))
+            create_slice_var(&state, var, func->impl);
+      }
+   }
+
+   bool progress = nir_shader_lower_instructions(shader,
+                                                 lower_cmat_filter,
+                                                 lower_cmat_instr,
+                                                 &state);
+
+   ralloc_free(temp_ctx);
+
+   return progress;
+}
diff --git a/src/intel/compiler/meson.build b/src/intel/compiler/meson.build
index 200cc9859fa..5fd08abd4fb 100644
--- a/src/intel/compiler/meson.build
+++ b/src/intel/compiler/meson.build
@@ -88,6 +88,7 @@ libintel_compiler_files = files(
   'brw_nir_blockify_uniform_loads.c',
   'brw_nir_clamp_per_vertex_loads.c',
   'brw_nir_lower_conversions.c',
+  'brw_nir_lower_cooperative_matrix.c',
   'brw_nir_lower_cs_intrinsics.c',
   'brw_nir_lower_alpha_to_coverage.c',
   'brw_nir_lower_intersection_shader.c',