mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-01-30 18:00:24 +01:00
radv/nir/lower_cmat: share cmat_load/cmat_store code
Reviewed-by: Rhys Perry <pendingchaos02@gmail.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/35633>
This commit is contained in:
parent
f3f67823c4
commit
ed2ecf9ef8
1 changed file with 30 additions and 82 deletions
|
|
@ -418,89 +418,18 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
|
|||
progress = true;
|
||||
break;
|
||||
}
|
||||
case nir_intrinsic_cmat_load: {
|
||||
nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
|
||||
struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
|
||||
enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);
|
||||
|
||||
nir_deref_instr *deref = nir_instr_as_deref(intr->src[1].ssa->parent_instr);
|
||||
nir_def *stride = intr->src[2].ssa;
|
||||
|
||||
nir_def *local_idx = nir_load_subgroup_invocation(&b);
|
||||
nir_def *inner_idx = nir_iand_imm(&b, local_idx, 15);
|
||||
|
||||
/* A input is transposed */
|
||||
if (desc.use == GLSL_CMAT_USE_A)
|
||||
layout = layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR ? GLSL_MATRIX_LAYOUT_ROW_MAJOR
|
||||
: GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;
|
||||
|
||||
unsigned length = radv_nir_cmat_length(desc, &params);
|
||||
unsigned mul = radv_nir_cmat_length_mul(desc, &params);
|
||||
unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? params.wave_size : 16;
|
||||
nir_def *vars[16];
|
||||
if (mul > 1) {
|
||||
for (unsigned i = 0; i < length; ++i)
|
||||
if (i % mul != 0)
|
||||
vars[i] = nir_undef(&b, 1, radv_nir_cmat_bits(desc));
|
||||
}
|
||||
|
||||
unsigned idx_bits = deref->def.bit_size;
|
||||
nir_def *base_row = radv_get_base_row(&b, desc, &params, local_idx);
|
||||
|
||||
for (unsigned i = 0; i < length / mul; ++i) {
|
||||
nir_def *col_offset = inner_idx;
|
||||
nir_def *row_offset;
|
||||
uint32_t row_iter;
|
||||
|
||||
if (gfx_level >= GFX12) {
|
||||
row_iter = i;
|
||||
} else {
|
||||
row_iter = i * lanes_per_iter / 16;
|
||||
}
|
||||
|
||||
row_offset = nir_iadd_imm(&b, base_row, row_iter);
|
||||
|
||||
if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) {
|
||||
nir_def *tmp = col_offset;
|
||||
col_offset = row_offset;
|
||||
row_offset = tmp;
|
||||
}
|
||||
|
||||
col_offset = nir_imul(&b, col_offset, stride);
|
||||
|
||||
col_offset = nir_u2uN(&b, col_offset, idx_bits);
|
||||
row_offset = nir_u2uN(&b, row_offset, idx_bits);
|
||||
|
||||
nir_deref_instr *iter_deref = nir_build_deref_ptr_as_array(&b, deref, col_offset);
|
||||
iter_deref = nir_build_deref_cast(&b, &iter_deref->def, deref->modes,
|
||||
glsl_scalar_type(desc.element_type), radv_nir_cmat_bits(desc) / 8);
|
||||
iter_deref = nir_build_deref_ptr_as_array(&b, iter_deref, row_offset);
|
||||
|
||||
vars[i * mul] = nir_load_deref(&b, iter_deref);
|
||||
}
|
||||
|
||||
nir_def *mat = nir_vec(&b, vars, length);
|
||||
nir_store_deref(&b, dst_deref, mat, nir_component_mask(mat->num_components));
|
||||
nir_instr_remove(instr);
|
||||
progress = true;
|
||||
break;
|
||||
}
|
||||
case nir_intrinsic_cmat_load:
|
||||
case nir_intrinsic_cmat_store: {
|
||||
const bool is_load = intr->intrinsic == nir_intrinsic_cmat_load;
|
||||
|
||||
nir_deref_instr *cmat_deref = nir_instr_as_deref(intr->src[!is_load].ssa->parent_instr);
|
||||
struct glsl_cmat_description desc = *glsl_get_cmat_description(cmat_deref->type);
|
||||
enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);
|
||||
|
||||
nir_deref_instr *deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
|
||||
nir_def *src = intr->src[1].ssa;
|
||||
nir_deref_instr *deref = nir_instr_as_deref(intr->src[is_load].ssa->parent_instr);
|
||||
nir_def *stride = intr->src[2].ssa;
|
||||
|
||||
nir_deref_instr *src_deref = nir_instr_as_deref(src->parent_instr);
|
||||
struct glsl_cmat_description desc = *glsl_get_cmat_description(src_deref->type);
|
||||
src = radv_nir_load_cmat(&b, &params, src);
|
||||
|
||||
nir_def *local_idx = nir_load_subgroup_invocation(&b);
|
||||
|
||||
if (gfx_level < GFX12 && desc.use != GLSL_CMAT_USE_ACCUMULATOR)
|
||||
nir_push_if(&b, nir_ilt_imm(&b, local_idx, 16));
|
||||
|
||||
nir_def *inner_idx = nir_iand_imm(&b, local_idx, 15);
|
||||
|
||||
/* A input is transposed */
|
||||
|
|
@ -512,8 +441,20 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
|
|||
unsigned mul = radv_nir_cmat_length_mul(desc, &params);
|
||||
unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? params.wave_size : 16;
|
||||
nir_def *vars[16];
|
||||
for (unsigned i = 0; i < length; ++i)
|
||||
vars[i] = nir_channel(&b, src, i);
|
||||
if (is_load) {
|
||||
if (mul > 1) {
|
||||
for (unsigned i = 0; i < length; ++i)
|
||||
if (i % mul != 0)
|
||||
vars[i] = nir_undef(&b, 1, radv_nir_cmat_bits(desc));
|
||||
}
|
||||
} else {
|
||||
if (gfx_level < GFX12 && desc.use != GLSL_CMAT_USE_ACCUMULATOR)
|
||||
nir_push_if(&b, nir_ilt_imm(&b, local_idx, 16));
|
||||
|
||||
nir_def *src = radv_nir_load_cmat(&b, &params, &cmat_deref->def);
|
||||
for (unsigned i = 0; i < length; ++i)
|
||||
vars[i] = nir_channel(&b, src, i);
|
||||
}
|
||||
|
||||
unsigned idx_bits = deref->def.bit_size;
|
||||
nir_def *base_row = radv_get_base_row(&b, desc, &params, local_idx);
|
||||
|
|
@ -547,12 +488,19 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
|
|||
glsl_scalar_type(desc.element_type), radv_nir_cmat_bits(desc) / 8);
|
||||
iter_deref = nir_build_deref_ptr_as_array(&b, iter_deref, row_offset);
|
||||
|
||||
nir_store_deref(&b, iter_deref, vars[i * mul], 1);
|
||||
if (is_load) {
|
||||
vars[i * mul] = nir_load_deref(&b, iter_deref);
|
||||
} else {
|
||||
nir_store_deref(&b, iter_deref, vars[i * mul], 1);
|
||||
}
|
||||
}
|
||||
|
||||
if (gfx_level < GFX12 && desc.use != GLSL_CMAT_USE_ACCUMULATOR)
|
||||
if (is_load) {
|
||||
nir_def *mat = nir_vec(&b, vars, length);
|
||||
nir_store_deref(&b, cmat_deref, mat, nir_component_mask(mat->num_components));
|
||||
} else if (gfx_level < GFX12 && desc.use != GLSL_CMAT_USE_ACCUMULATOR) {
|
||||
nir_pop_if(&b, NULL);
|
||||
|
||||
}
|
||||
nir_instr_remove(instr);
|
||||
progress = true;
|
||||
break;
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue