diff --git a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c index f00323c4768..cc1a3ac2ea6 100644 --- a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c +++ b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c @@ -13,7 +13,7 @@ * to 16..31 and for wave64 also into lanes 32..47 and 48..63. A&B matrices are * always vectors of 16 elements. * - * On GFX12, there is no data replication and the matrices layout is described + * On GFX11.7+, there is no data replication and the matrices layout is described * as below: * * Wave32: @@ -54,7 +54,7 @@ radv_nir_cmat_bits(struct glsl_cmat_description desc) static unsigned radv_nir_cmat_length(struct glsl_cmat_description desc, const lower_cmat_params *params) { - if (params->gfx_level >= GFX12) { + if (params->gfx_level >= GFX11_7) { assert(desc.cols == 16 && desc.rows == 16); return 256 / params->wave_size; } else if (desc.use != GLSL_CMAT_USE_ACCUMULATOR) { @@ -67,7 +67,7 @@ radv_nir_cmat_length(struct glsl_cmat_description desc, const lower_cmat_params static unsigned radv_nir_cmat_length_mul(struct glsl_cmat_description desc, const lower_cmat_params *params) { - if (params->gfx_level >= GFX12 || desc.use != GLSL_CMAT_USE_ACCUMULATOR) { + if (params->gfx_level >= GFX11_7 || desc.use != GLSL_CMAT_USE_ACCUMULATOR) { return 1; } else { /* For GFX11 C matrices we have 1 VGPR per element even if the element type is @@ -148,7 +148,7 @@ radv_get_base_row(nir_builder *b, struct glsl_cmat_description desc, const lower { nir_def *base_row; - if (params->gfx_level >= GFX12) { + if (params->gfx_level >= GFX11_7) { base_row = nir_udiv_imm(b, local_idx, 16); if (params->wave_size == 64) { @@ -259,7 +259,7 @@ lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intr, const lower_cma nir_def *local_idx = nir_load_subgroup_invocation(b); nir_def *inner_idx = nir_iand_imm(b, local_idx, 15); - bool load_acc_as_b = is_load && params->gfx_level < GFX12 && desc.use 
== GLSL_CMAT_USE_ACCUMULATOR && + bool load_acc_as_b = is_load && params->gfx_level < GFX11_7 && desc.use == GLSL_CMAT_USE_ACCUMULATOR && radv_nir_cmat_bits(desc) == 8 && params->wave_size == 32 && layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR; if (load_acc_as_b) @@ -281,7 +281,7 @@ lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intr, const lower_cma vars[i] = nir_undef(b, 1, radv_nir_cmat_bits(desc)); } } else { - if (params->gfx_level < GFX12 && desc.use != GLSL_CMAT_USE_ACCUMULATOR) + if (params->gfx_level < GFX11_7 && desc.use != GLSL_CMAT_USE_ACCUMULATOR) nir_push_if(b, nir_ilt_imm(b, local_idx, 16)); nir_def *src = radv_nir_load_cmat(b, params, &cmat_deref->def); @@ -302,7 +302,7 @@ lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intr, const lower_cma if (layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR) align_mul = MIN2(16, radv_nir_cmat_bits(desc) * desc.rows / 8); - if (params->gfx_level >= GFX12) + if (params->gfx_level >= GFX11_7) align_mul /= params->wave_size / 16; else if (desc.use == GLSL_CMAT_USE_ACCUMULATOR) align_mul = 0; @@ -312,7 +312,7 @@ lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intr, const lower_cma nir_def *row_offset; uint32_t row_iter; - if (params->gfx_level >= GFX12) { + if (params->gfx_level >= GFX11_7) { row_iter = i; } else { row_iter = i * lanes_per_iter / 16; @@ -366,7 +366,7 @@ lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intr, const lower_cma } nir_store_deref(b, cmat_deref, mat, nir_component_mask(mat->num_components)); - } else if (params->gfx_level < GFX12 && desc.use != GLSL_CMAT_USE_ACCUMULATOR) { + } else if (params->gfx_level < GFX11_7 && desc.use != GLSL_CMAT_USE_ACCUMULATOR) { nir_pop_if(b, NULL); } nir_instr_remove(&intr->instr); @@ -445,7 +445,7 @@ convert_use(nir_builder *b, nir_def *src, enum glsl_cmat_use src_use, enum glsl_ { if (src_use == dst_use) return src; - if (params->gfx_level >= GFX12) { + if (params->gfx_level >= GFX11_7) { if (src_use == GLSL_CMAT_USE_B && 
dst_use == GLSL_CMAT_USE_ACCUMULATOR) return src; if (src_use == GLSL_CMAT_USE_ACCUMULATOR && dst_use == GLSL_CMAT_USE_B) @@ -467,7 +467,7 @@ convert_use(nir_builder *b, nir_def *src, enum glsl_cmat_use src_use, enum glsl_ components[i] = nir_channel(b, src, i); if (src_use == GLSL_CMAT_USE_ACCUMULATOR && dst_use == GLSL_CMAT_USE_B) { - assert(params->gfx_level < GFX12); + assert(params->gfx_level < GFX11_7); nir_def *tmp[NIR_MAX_VEC_COMPONENTS]; if (src->bit_size == 32) { @@ -521,7 +521,7 @@ convert_use(nir_builder *b, nir_def *src, enum glsl_cmat_use src_use, enum glsl_ assert(num_comps == 16); } else if (src_use == GLSL_CMAT_USE_B && dst_use == GLSL_CMAT_USE_ACCUMULATOR) { - assert(params->gfx_level < GFX12); + assert(params->gfx_level < GFX11_7); assert(num_comps == 16); if (src->bit_size == 32) { for (unsigned keep32 = 0; keep32 < ((params->wave_size == 64) ? 2 : 1); keep32++) { @@ -585,9 +585,9 @@ convert_use(nir_builder *b, nir_def *src, enum glsl_cmat_use src_use, enum glsl_ } } - assert(num_comps == 16 || params->gfx_level >= GFX12); + assert(num_comps == 16 || params->gfx_level >= GFX11_7); - if (params->gfx_level >= GFX12) { + if (params->gfx_level >= GFX11_7) { /* One component contains 2/4 rows in wave32/64, so we must transpose inside it. */ for (int cross32 = params->wave_size == 64; cross32 >= 0; cross32--) { uint64_t even = cross32 ? 
0xf0f0f0f00f0f0f0f : 0xff0000ffff0000ff; @@ -901,7 +901,7 @@ lower_cmat_reduce_2x2_call(nir_builder *b, nir_cmat_call_instr *call, const lowe nir_def *low16 = nir_inverse_ballot_imm(b, 0xffff0000ffff, params->wave_size); for (unsigned m = 0; m < 4; m++) { for (unsigned i = 0; i < length / mul / 2; i++) { - if (params->gfx_level >= GFX12) { + if (params->gfx_level >= GFX11_7) { /* The neighboring row is in the VGPR next to us */ nir_call(b, fnptr, &qd_tmp_deref->def, src_components[m][i * 2], src_components[m][i * 2 + 1]); src_components[m][i] = nir_load_deref(b, qd_tmp_deref); @@ -946,7 +946,7 @@ lower_cmat_reduce_2x2_call(nir_builder *b, nir_cmat_call_instr *call, const lowe for (unsigned i = 0; i < length / mul; i++) vars[i] = nir_lane_permute_16_amd(b, vars[i], perm_low, perm_high); - if (params->gfx_level >= GFX12) { - /* For GFX12, we still have to swap the row(s) in upper half coming from the bottom two + if (params->gfx_level >= GFX11_7) { + /* For GFX11.7+, we still have to swap the row(s) in upper half coming from the bottom two * matrices with low row(s) in the from other two matrices. */ @@ -1021,7 +1021,7 @@ lower_cmat_per_element_op(nir_builder *b, nir_cmat_call_instr *call, const lower nir_call_instr *new_call = nir_call_instr_create(b->shader, fnptr); uint32_t row_iter; - if (params->gfx_level >= GFX12) { + if (params->gfx_level >= GFX11_7) { row_iter = i; } else { row_iter = i * lanes_per_iter / 16; @@ -1205,7 +1205,7 @@ apply_component_mods(nir_scalar *comp, unsigned num_comps, unsigned stride, nir_ static bool opt_cmat_modifiers(nir_builder *b, nir_intrinsic_instr *intrin, enum amd_gfx_level gfx_level, unsigned src_idx) { - unsigned length_mul = src_idx == 2 && intrin->src[2].ssa->bit_size == 16 && gfx_level < GFX12 ? 2 : 1; + unsigned length_mul = src_idx == 2 && intrin->src[2].ssa->bit_size == 16 && gfx_level < GFX11_7 ? 2 : 1; nir_scalar comp[NIR_MAX_VEC_COMPONENTS] = {0}; nir_def *src = intrin->src[src_idx].ssa;