mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-06-17 19:38:21 +02:00
radv: use load_deref_transpose_amd for transposed cooperative matrix loads
This requires that BDA is used or robust buffer access is disabled for cooperative matrix loads.

No fossil-db changes.

fossil-db (gfx1201, dEQP-VK.compute.pipeline.cooperative_matrix.*):
Totals from 603 (37.15% of 1623) affected shaders:
Instrs: 57422 -> 51212 (-10.81%); split: -11.99%, +1.17%
CodeSize: 357444 -> 310688 (-13.08%); split: -13.70%, +0.62%
VGPRs: 16668 -> 13188 (-20.88%); split: -21.53%, +0.65%
Latency: 492820 -> 469600 (-4.71%); split: -4.82%, +0.11%
InvThroughput: 63548 -> 55754 (-12.26%); split: -13.09%, +0.82%
VClause: 1624 -> 1620 (-0.25%); split: -2.71%, +2.46%
Copies: 2965 -> 3175 (+7.08%); split: -15.41%, +22.50%
PreSGPRs: 6966 -> 5450 (-21.76%); split: -21.91%, +0.14%
PreVGPRs: 7049 -> 5978 (-15.19%); split: -15.39%, +0.20%
VALU: 27454 -> 24315 (-11.43%); split: -14.74%, +3.30%
SALU: 5996 -> 6997 (+16.69%)
VMEM: 6748 -> 4656 (-31.00%)
SMEM: 1225 -> 1577 (+28.73%)
VOPD: 45 -> 39 (-13.33%); split: +4.44%, -17.78%

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/41653>
This commit is contained in:
parent
0cdb7594d7
commit
ca0496bc26
3 changed files with 85 additions and 8 deletions
|
|
@ -66,7 +66,8 @@ void radv_nir_lower_io(nir_shader *nir);
|
||||||
|
|
||||||
bool radv_nir_lower_io_to_mem(const struct radv_compiler_info *compiler_info, struct radv_shader_stage *stage);
|
bool radv_nir_lower_io_to_mem(const struct radv_compiler_info *compiler_info, struct radv_shader_stage *stage);
|
||||||
|
|
||||||
bool radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_level, unsigned wave_size);
|
bool radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_level,
|
||||||
|
struct radv_shader_stage *stage, unsigned wave_size);
|
||||||
|
|
||||||
bool radv_nir_opt_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_level);
|
bool radv_nir_opt_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_level);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@
|
||||||
|
|
||||||
#include "nir_builder.h"
|
#include "nir_builder.h"
|
||||||
#include "radv_nir.h"
|
#include "radv_nir.h"
|
||||||
|
#include "radv_shader.h"
|
||||||
|
|
||||||
/* This pass lowers cooperative matrix.
|
/* This pass lowers cooperative matrix.
|
||||||
*
|
*
|
||||||
|
|
@ -43,6 +44,8 @@
|
||||||
typedef struct {
|
typedef struct {
|
||||||
enum amd_gfx_level gfx_level;
|
enum amd_gfx_level gfx_level;
|
||||||
unsigned wave_size;
|
unsigned wave_size;
|
||||||
|
bool ubo_robustness;
|
||||||
|
bool ssbo_robustness;
|
||||||
} lower_cmat_params;
|
} lower_cmat_params;
|
||||||
|
|
||||||
static unsigned
|
static unsigned
|
||||||
|
|
@ -241,6 +244,39 @@ lower_cmat_construct(nir_builder *b, nir_intrinsic_instr *intr, const lower_cmat
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Compute, per invocation, the matrix row and column that the lane should
 * address, and return them through *row and *col.
 *
 * The values are derived from nir_load_subgroup_invocation():
 *  - bit_size == 16: row = (lane / 16) * 8 + (lane & 7); col is 8 for lanes
 *    selected by the 0xff00ff00... ballot mask and 0 otherwise.
 *  - otherwise (asserted to be bit_size == 8): row = (lane / 8) * 4 + (lane & 3);
 *    col is 8 for lanes selected by the 0xf0f0f0f0... ballot mask and 0 otherwise.
 *
 * NOTE(review): the result is not scaled by any stride and not resized here;
 * the caller is presumably responsible for both - confirm against the call site.
 */
static void
|
||||||
|
get_load_tr_row_col(nir_builder *b, unsigned bit_size, nir_def **row, nir_def **col)
|
||||||
|
{
|
||||||
|
nir_def *lane_id = nir_load_subgroup_invocation(b);
|
||||||
|
|
||||||
|
/* In wave64, the instruction only cares about the address for lanes 0-31. */
|
||||||
|
if (bit_size == 16) {
|
||||||
|
/*
|
||||||
|
* lane: 0..7 | 8..15 | 16..23 | 24..31
|
||||||
|
* row: 0..7 | 0..7 | 8..15 | 8..15
|
||||||
|
* column: 0 | 8 | 0 | 8
|
||||||
|
*/
|
||||||
|
*row = nir_imul_imm(b, nir_udiv_imm(b, lane_id, 16), 8);
|
||||||
|
*row = nir_iadd(b, *row, nir_iand_imm(b, lane_id, 7));
|
||||||
|
|
||||||
|
nir_def *odd8 = nir_inverse_ballot_imm(b, UINT64_C(0xff00ff00ff00ff00), b->shader->info.api_subgroup_size);
|
||||||
|
*col = nir_bcsel(b, odd8, nir_imm_int(b, 8), nir_imm_int(b, 0));
|
||||||
|
} else {
|
||||||
|
/*
|
||||||
|
* lane: 0..3 | 4..7 | 8..11 | 12..15 | 16..19 | 20..23 | 24..27 | 28..31
|
||||||
|
* row: 0..3 | 0..3 | 4..7 | 4..7 | 8..11 | 8..11 | 12..15 | 12..15
|
||||||
|
* column: 0 | 8 | 0 | 8 | 0 | 8 | 0 | 8
|
||||||
|
*/
|
||||||
|
assert(bit_size == 8);
|
||||||
|
|
||||||
|
*row = nir_imul_imm(b, nir_udiv_imm(b, lane_id, 8), 4);
|
||||||
|
*row = nir_iadd(b, *row, nir_iand_imm(b, lane_id, 3));
|
||||||
|
|
||||||
|
nir_def *odd4 = nir_inverse_ballot_imm(b, UINT64_C(0xf0f0f0f0f0f0f0f0), b->shader->info.api_subgroup_size);
|
||||||
|
*col = nir_bcsel(b, odd4, nir_imm_int(b, 8), nir_imm_int(b, 0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static bool
|
static bool
|
||||||
lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intr, const lower_cmat_params *params)
|
lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intr, const lower_cmat_params *params)
|
||||||
{
|
{
|
||||||
|
|
@ -249,16 +285,15 @@ lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intr, const lower_cma
|
||||||
nir_deref_instr *cmat_deref = nir_src_as_deref(intr->src[!is_load]);
|
nir_deref_instr *cmat_deref = nir_src_as_deref(intr->src[!is_load]);
|
||||||
struct glsl_cmat_description desc = *glsl_get_cmat_description(cmat_deref->type);
|
struct glsl_cmat_description desc = *glsl_get_cmat_description(cmat_deref->type);
|
||||||
enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);
|
enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);
|
||||||
|
unsigned length = radv_nir_cmat_length(desc, params);
|
||||||
|
|
||||||
nir_deref_instr *deref = nir_src_as_deref(intr->src[is_load]);
|
nir_deref_instr *deref = nir_src_as_deref(intr->src[is_load]);
|
||||||
nir_def *stride = intr->src[2].ssa;
|
nir_def *stride = intr->src[2].ssa;
|
||||||
|
|
||||||
const uint32_t ptr_stride = glsl_get_bit_size(deref->type) / 8 * glsl_get_vector_elements(deref->type);
|
const uint32_t ptr_stride = glsl_get_bit_size(deref->type) / 8 * glsl_get_vector_elements(deref->type);
|
||||||
|
const unsigned idx_bits = deref->def.bit_size;
|
||||||
deref = nir_build_deref_cast(b, &deref->def, deref->modes, deref->type, ptr_stride);
|
deref = nir_build_deref_cast(b, &deref->def, deref->modes, deref->type, ptr_stride);
|
||||||
|
|
||||||
nir_def *local_idx = nir_load_subgroup_invocation(b);
|
|
||||||
nir_def *inner_idx = nir_iand_imm(b, local_idx, 15);
|
|
||||||
|
|
||||||
bool load_acc_as_b = is_load && params->gfx_level < GFX11_7 && desc.use == GLSL_CMAT_USE_ACCUMULATOR &&
|
bool load_acc_as_b = is_load && params->gfx_level < GFX11_7 && desc.use == GLSL_CMAT_USE_ACCUMULATOR &&
|
||||||
radv_nir_cmat_bits(desc) == 8 && params->wave_size == 32 &&
|
radv_nir_cmat_bits(desc) == 8 && params->wave_size == 32 &&
|
||||||
layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;
|
layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;
|
||||||
|
|
@ -270,7 +305,46 @@ lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intr, const lower_cma
|
||||||
layout =
|
layout =
|
||||||
layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR ? GLSL_MATRIX_LAYOUT_ROW_MAJOR : GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;
|
layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR ? GLSL_MATRIX_LAYOUT_ROW_MAJOR : GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;
|
||||||
|
|
||||||
unsigned length = radv_nir_cmat_length(desc, params);
|
bool use_tr_load = params->gfx_level >= GFX12 && layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR && is_load &&
|
||||||
|
radv_nir_cmat_bits(desc) < 32 &&
|
||||||
|
(nir_deref_mode_is(deref, nir_var_mem_global) ||
|
||||||
|
(nir_deref_mode_is(deref, nir_var_mem_ubo) && !params->ubo_robustness) ||
|
||||||
|
(nir_deref_mode_is(deref, nir_var_mem_ssbo) && !params->ssbo_robustness));
|
||||||
|
|
||||||
|
if (use_tr_load) {
|
||||||
|
assert(!load_acc_as_b);
|
||||||
|
|
||||||
|
const unsigned elem_bits = radv_nir_cmat_bits(desc);
|
||||||
|
nir_def *row, *col;
|
||||||
|
get_load_tr_row_col(b, elem_bits, &row, &col);
|
||||||
|
col = nir_u2uN(b, col, idx_bits);
|
||||||
|
row = nir_u2uN(b, nir_imul(b, row, stride), idx_bits);
|
||||||
|
|
||||||
|
deref = nir_build_deref_ptr_as_array(b, deref, row);
|
||||||
|
deref = nir_build_deref_cast(b, &deref->def, deref->modes, glsl_scalar_type(desc.element_type), elem_bits / 8);
|
||||||
|
deref = nir_build_deref_ptr_as_array(b, deref, col);
|
||||||
|
|
||||||
|
/* Convert buffer deref to a global one. */
|
||||||
|
if (nir_deref_mode_is_one_of(deref, nir_var_mem_ssbo | nir_var_mem_ubo)) {
|
||||||
|
nir_def *descriptor = nir_ssbo_descriptor_amd(b, &deref->def);
|
||||||
|
nir_def *addr_lo = nir_channel(b, descriptor, 0);
|
||||||
|
nir_def *addr_hi = nir_extract_i16(b, nir_channel(b, descriptor, 1), nir_imm_int(b, 0));
|
||||||
|
nir_def *addr = nir_pack_64_2x32_split(b, addr_lo, addr_hi);
|
||||||
|
|
||||||
|
nir_def *offset = nir_channel(b, &deref->def, 2);
|
||||||
|
addr = nir_iadd_nuw(b, addr, nir_u2u64(b, offset));
|
||||||
|
deref = nir_build_deref_cast(b, addr, nir_var_mem_global, deref->type, elem_bits / 8);
|
||||||
|
}
|
||||||
|
|
||||||
|
nir_def *mat = nir_load_deref_transpose_amd(b, length, elem_bits, &deref->def);
|
||||||
|
nir_store_deref(b, cmat_deref, mat, nir_component_mask(mat->num_components));
|
||||||
|
nir_instr_remove(&intr->instr);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
nir_def *local_idx = nir_load_subgroup_invocation(b);
|
||||||
|
nir_def *inner_idx = nir_iand_imm(b, local_idx, 15);
|
||||||
|
|
||||||
unsigned mul = radv_nir_cmat_length_mul(desc, params);
|
unsigned mul = radv_nir_cmat_length_mul(desc, params);
|
||||||
unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? params->wave_size : 16;
|
unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? params->wave_size : 16;
|
||||||
nir_def *vars[16];
|
nir_def *vars[16];
|
||||||
|
|
@ -289,7 +363,6 @@ lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intr, const lower_cma
|
||||||
vars[i] = nir_channel(b, src, i);
|
vars[i] = nir_channel(b, src, i);
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned idx_bits = deref->def.bit_size;
|
|
||||||
nir_def *base_row = radv_get_base_row(b, desc, params, local_idx);
|
nir_def *base_row = radv_get_base_row(b, desc, params, local_idx);
|
||||||
|
|
||||||
/* VUID-RuntimeSpirv-OpCooperativeMatrixLoadKHR-08986:
|
/* VUID-RuntimeSpirv-OpCooperativeMatrixLoadKHR-08986:
|
||||||
|
|
@ -1064,7 +1137,8 @@ lower_cmat_per_element_op(nir_builder *b, nir_cmat_call_instr *call, const lower
|
||||||
}
|
}
|
||||||
|
|
||||||
bool
|
bool
|
||||||
radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_level, unsigned wave_size)
|
radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_level, struct radv_shader_stage *stage,
|
||||||
|
unsigned wave_size)
|
||||||
{
|
{
|
||||||
bool progress = false;
|
bool progress = false;
|
||||||
|
|
||||||
|
|
@ -1074,6 +1148,8 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
|
||||||
const lower_cmat_params params = {
|
const lower_cmat_params params = {
|
||||||
.gfx_level = gfx_level,
|
.gfx_level = gfx_level,
|
||||||
.wave_size = wave_size,
|
.wave_size = wave_size,
|
||||||
|
.ubo_robustness = stage->key.coop_matrix_uniform_robustness,
|
||||||
|
.ssbo_robustness = stage->key.coop_matrix_storage_robustness,
|
||||||
};
|
};
|
||||||
|
|
||||||
struct nir_function *func = (struct nir_function *)exec_list_get_head_const(&shader->functions);
|
struct nir_function *func = (struct nir_function *)exec_list_get_head_const(&shader->functions);
|
||||||
|
|
|
||||||
|
|
@ -579,7 +579,7 @@ radv_shader_spirv_to_nir(const struct radv_compiler_info *compiler_info, struct
|
||||||
NIR_PASS(_, nir, nir_remove_dead_variables, nir_var_function_temp | nir_var_shader_temp, NULL);
|
NIR_PASS(_, nir, nir_remove_dead_variables, nir_var_function_temp | nir_var_shader_temp, NULL);
|
||||||
}
|
}
|
||||||
|
|
||||||
NIR_PASS(progress, nir, radv_nir_lower_cooperative_matrix, compiler_info->ac->gfx_level,
|
NIR_PASS(progress, nir, radv_nir_lower_cooperative_matrix, compiler_info->ac->gfx_level, stage,
|
||||||
nir->info.max_subgroup_size);
|
nir->info.max_subgroup_size);
|
||||||
if (progress) {
|
if (progress) {
|
||||||
NIR_PASS(_, nir, nir_opt_dce);
|
NIR_PASS(_, nir, nir_opt_dce);
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue