From ad611adeb772ca82460eea328acf57e6333304d9 Mon Sep 17 00:00:00 2001 From: Samuel Pitoiset Date: Wed, 22 Jan 2025 03:46:01 -0800 Subject: [PATCH] radv/nir: add a struct for parameters to cooperative matrix lowering Signed-off-by: Samuel Pitoiset Part-of: --- .../nir/radv_nir_lower_cooperative_matrix.c | 69 +++++++++++-------- 1 file changed, 39 insertions(+), 30 deletions(-) diff --git a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c index e3321a1c65a..884ceaa97ca 100644 --- a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c +++ b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c @@ -7,12 +7,16 @@ #include "nir_builder.h" #include "radv_nir.h" +typedef struct { + unsigned wave_size; +} lower_cmat_params; + static unsigned -radv_nir_cmat_length(struct glsl_cmat_description desc, unsigned wave_size) +radv_nir_cmat_length(struct glsl_cmat_description desc, const lower_cmat_params *params) { return desc.use != GLSL_CMAT_USE_ACCUMULATOR ? 16 - : (desc.cols * desc.rows / wave_size * 32 / glsl_base_type_bit_size(desc.element_type)); + : (desc.cols * desc.rows / params->wave_size * 32 / glsl_base_type_bit_size(desc.element_type)); } /* for C matrices we have 1 VGPR per element even if the element type is < 32 bits. So with 8 fp16 elements we implement @@ -32,28 +36,29 @@ radv_nir_cmat_bits(struct glsl_cmat_description desc) } static nir_def * -radv_nir_load_cmat(nir_builder *b, unsigned wave_size, nir_def *src) +radv_nir_load_cmat(nir_builder *b, const lower_cmat_params *params, nir_def *src) { nir_deref_instr *deref = nir_instr_as_deref(src->parent_instr); struct glsl_cmat_description desc = *glsl_get_cmat_description(deref->type); - return nir_build_load_deref(b, radv_nir_cmat_length(desc, wave_size), glsl_base_type_bit_size(desc.element_type), - src, 0); + return nir_build_load_deref(b, radv_nir_cmat_length(desc, params), glsl_base_type_bit_size(desc.element_type), src, + 0); } static const struct glsl_type * -radv_nir_translate_matrix_type(const struct glsl_type *orig_type, struct hash_table *type_map, unsigned wave_size) +radv_nir_translate_matrix_type(const struct glsl_type *orig_type, struct hash_table *type_map, + const lower_cmat_params *params) { struct hash_entry *entry = _mesa_hash_table_search(type_map, orig_type); if (entry) { return entry->data; } else if (glsl_type_is_cmat(orig_type)) { struct glsl_cmat_description desc = *glsl_get_cmat_description(orig_type); - unsigned length = radv_nir_cmat_length(desc, wave_size); + unsigned length = radv_nir_cmat_length(desc, params); return glsl_vector_type(desc.element_type, length); } else if (glsl_type_is_array(orig_type)) { const struct glsl_type *elem_type = glsl_get_array_element(orig_type); - const struct glsl_type *new_elem_type = radv_nir_translate_matrix_type(elem_type, type_map, wave_size); + const struct glsl_type *new_elem_type = radv_nir_translate_matrix_type(elem_type, type_map, params); if (elem_type == new_elem_type) return orig_type; @@ -65,7 +70,7 @@ radv_nir_translate_matrix_type(const struct glsl_type *orig_type, struct hash_ta bool change = false; for (unsigned i = 0; i < num_fields; ++i) { const struct glsl_type *field_type = glsl_get_struct_field(orig_type, i); - const struct glsl_type *new_field_type = radv_nir_translate_matrix_type(field_type, type_map, wave_size); + const struct glsl_type *new_field_type = radv_nir_translate_matrix_type(field_type, type_map, params); if (field_type != new_field_type) { change = true; @@ -81,7 +86,7 @@ radv_nir_translate_matrix_type(const struct glsl_type *orig_type, struct hash_ta for (unsigned i = 0; i < num_fields; ++i) { fields[i] = *glsl_get_struct_field_data(orig_type, i); - fields[i].type = radv_nir_translate_matrix_type(fields[i].type, type_map, wave_size); + fields[i].type = radv_nir_translate_matrix_type(fields[i].type, type_map, params); } const struct glsl_type *ret = @@ -102,11 +107,15 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size) if (!shader->info.cs.has_cooperative_matrix) return false; + const lower_cmat_params params = { + .wave_size = wave_size, + }; + struct nir_function *func = (struct nir_function *)exec_list_get_head_const(&shader->functions); struct hash_table *type_map = _mesa_pointer_hash_table_create(NULL); nir_foreach_variable_with_modes (var, shader, nir_var_shader_temp) { - const struct glsl_type *new_type = radv_nir_translate_matrix_type(var->type, type_map, wave_size); + const struct glsl_type *new_type = radv_nir_translate_matrix_type(var->type, type_map, ¶ms); if (new_type != var->type) { var->type = new_type; progress = true; @@ -114,7 +123,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size) } nir_foreach_function_temp_variable (var, func->impl) { - const struct glsl_type *new_type = radv_nir_translate_matrix_type(var->type, type_map, wave_size); + const struct glsl_type *new_type = radv_nir_translate_matrix_type(var->type, type_map, ¶ms); if (new_type != var->type) { var->type = new_type; progress = true; @@ -134,7 +143,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size) switch (intr->intrinsic) { case nir_intrinsic_cmat_length: { struct glsl_cmat_description desc = nir_intrinsic_cmat_desc(intr); - unsigned len = radv_nir_cmat_length(desc, wave_size) / radv_nir_cmat_length_mul(desc); + unsigned len = radv_nir_cmat_length(desc, ¶ms) / radv_nir_cmat_length_mul(desc); nir_def_rewrite_uses(&intr->def, nir_imm_int(&b, len)); nir_instr_remove(instr); progress = true; @@ -143,7 +152,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size) case nir_intrinsic_cmat_extract: { nir_deref_instr *src_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr); struct glsl_cmat_description desc = *glsl_get_cmat_description(src_deref->type); - nir_def *src0 = radv_nir_load_cmat(&b, wave_size, intr->src[0].ssa); + nir_def *src0 = radv_nir_load_cmat(&b, ¶ms, intr->src[0].ssa); nir_def *index = intr->src[1].ssa; index = nir_imul_imm(&b, index, radv_nir_cmat_length_mul(desc)); @@ -156,7 +165,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size) break; } case nir_intrinsic_cmat_insert: { - nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[2].ssa); + nir_def *src1 = radv_nir_load_cmat(&b, ¶ms, intr->src[2].ssa); nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr); struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type); nir_def *index = intr->src[3].ssa; @@ -174,7 +183,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size) struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type); nir_def *elem = intr->src[1].ssa; - nir_def *r = nir_replicate(&b, elem, radv_nir_cmat_length(desc, wave_size)); + nir_def *r = nir_replicate(&b, elem, radv_nir_cmat_length(desc, ¶ms)); nir_store_deref(&b, dst_deref, r, nir_component_mask(r->num_components)); nir_instr_remove(instr); @@ -197,9 +206,9 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size) layout = layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR ? GLSL_MATRIX_LAYOUT_ROW_MAJOR : GLSL_MATRIX_LAYOUT_COLUMN_MAJOR; - unsigned length = radv_nir_cmat_length(desc, wave_size); + unsigned length = radv_nir_cmat_length(desc, ¶ms); unsigned mul = radv_nir_cmat_length_mul(desc); - unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? wave_size : 16; + unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? params.wave_size : 16; nir_def *vars[16]; if (mul > 1) { for (unsigned i = 0; i < length; ++i) @@ -250,7 +259,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size) nir_deref_instr *src_deref = nir_instr_as_deref(src->parent_instr); struct glsl_cmat_description desc = *glsl_get_cmat_description(src_deref->type); - src = radv_nir_load_cmat(&b, wave_size, src); + src = radv_nir_load_cmat(&b, ¶ms, src); nir_def *local_idx = nir_load_subgroup_invocation(&b); @@ -264,9 +273,9 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size) layout = layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR ? GLSL_MATRIX_LAYOUT_ROW_MAJOR : GLSL_MATRIX_LAYOUT_COLUMN_MAJOR; - unsigned length = radv_nir_cmat_length(desc, wave_size); + unsigned length = radv_nir_cmat_length(desc, ¶ms); unsigned mul = radv_nir_cmat_length_mul(desc); - unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? wave_size : 16; + unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? params.wave_size : 16; nir_def *vars[16]; for (unsigned i = 0; i < length; ++i) vars[i] = nir_channel(&b, src, i); @@ -307,9 +316,9 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size) break; } case nir_intrinsic_cmat_muladd: { - nir_def *A = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa); - nir_def *B = radv_nir_load_cmat(&b, wave_size, intr->src[2].ssa); - nir_def *C = radv_nir_load_cmat(&b, wave_size, intr->src[3].ssa); + nir_def *A = radv_nir_load_cmat(&b, ¶ms, intr->src[1].ssa); + nir_def *B = radv_nir_load_cmat(&b, ¶ms, intr->src[2].ssa); + nir_def *C = radv_nir_load_cmat(&b, ¶ms, intr->src[3].ssa); nir_def *ret; ret = nir_cmat_muladd_amd(&b, A, B, C, .saturate = nir_intrinsic_saturate(intr), @@ -326,7 +335,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size) nir_deref_instr *src_deref = nir_instr_as_deref(intr->src[1].ssa->parent_instr); struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type); struct glsl_cmat_description src_desc = *glsl_get_cmat_description(src_deref->type); - nir_def *src = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa); + nir_def *src = radv_nir_load_cmat(&b, ¶ms, intr->src[1].ssa); nir_op op = nir_intrinsic_alu_op(intr); if (glsl_base_type_bit_size(src_desc.element_type) == 16 && @@ -356,7 +365,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size) break; } case nir_intrinsic_cmat_scalar_op: { - nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa); + nir_def *src1 = radv_nir_load_cmat(&b, ¶ms, intr->src[1].ssa); nir_op op = nir_intrinsic_alu_op(intr); nir_def *ret = nir_build_alu2(&b, op, src1, intr->src[2].ssa); nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret, @@ -366,8 +375,8 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size) break; } case nir_intrinsic_cmat_binary_op: { - nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa); - nir_def *src2 = radv_nir_load_cmat(&b, wave_size, intr->src[2].ssa); + nir_def *src1 = radv_nir_load_cmat(&b, ¶ms, intr->src[1].ssa); + nir_def *src2 = radv_nir_load_cmat(&b, ¶ms, intr->src[2].ssa); nir_op op = nir_intrinsic_alu_op(intr); nir_def *ret = nir_build_alu2(&b, op, src1, src2); nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret, @@ -377,7 +386,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size) break; } case nir_intrinsic_cmat_bitcast: { - nir_def *src1 = radv_nir_load_cmat(&b, wave_size, intr->src[1].ssa); + nir_def *src1 = radv_nir_load_cmat(&b, ¶ms, intr->src[1].ssa); nir_store_deref(&b, nir_instr_as_deref(intr->src[0].ssa->parent_instr), src1, nir_component_mask(src1->num_components)); nir_instr_remove(instr); @@ -397,7 +406,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size) } case nir_instr_type_deref: { nir_deref_instr *deref = nir_instr_as_deref(instr); - const struct glsl_type *new_type = radv_nir_translate_matrix_type(deref->type, type_map, wave_size); + const struct glsl_type *new_type = radv_nir_translate_matrix_type(deref->type, type_map, ¶ms); if (new_type != deref->type) { deref->type = new_type; progress = true;