ir3: Move the compute shader threadsize forcing earlier.

With this, we can look at real_wavesize while running NIR passes and know
if we have to be doubled because of the shader info coming in.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/37245>
This commit is contained in:
Emma Anholt 2025-09-08 14:37:11 -07:00 committed by Marge Bot
parent 5a09abe890
commit d5cb38e457
2 changed files with 50 additions and 31 deletions

View file

@ -220,37 +220,7 @@ ir3_should_double_threadsize(struct ir3_shader_variant *v, unsigned regs_count)
switch (v->type) { switch (v->type) {
case MESA_SHADER_KERNEL: case MESA_SHADER_KERNEL:
case MESA_SHADER_COMPUTE: { case MESA_SHADER_COMPUTE:
unsigned threads_per_wg =
v->local_size[0] * v->local_size[1] * v->local_size[2];
/* If the workgroups fit in the base threadsize, then doubling would just
* leave us with an unused second half of each wave for no gain (The HW
* can't pack multiple workgroups into a wave, because the workgroups
* might make different barrier choices).
*/
if (!v->local_size_variable) {
if (threads_per_wg <= compiler->threadsize_base)
return false;
}
/* For a5xx, if the workgroup size is greater than the maximum number
* of threads per core with 32 threads per wave (512) then we have to
* use the doubled threadsize because otherwise the workgroup wouldn't
* fit. For smaller workgroup sizes, we follow the blob and use the
* smaller threadsize.
*
* For a6xx, because threadsize_base is bumped to 64, we don't have to
* worry about the workgroup fitting.
*/
if (compiler->gen < 6) {
return v->local_size_variable ||
threads_per_wg >
compiler->threadsize_base * compiler->max_waves;
}
}
FALLTHROUGH;
case MESA_SHADER_FRAGMENT: { case MESA_SHADER_FRAGMENT: {
/* One of the limits on maximum waves of the shader running in parallel is /* One of the limits on maximum waves of the shader running in parallel is
* the register count used in the shader compared to the hardware's * the register count used in the shader compared to the hardware's

View file

@ -1113,6 +1113,53 @@ atomic_supported(const nir_instr * instr, const void * data)
return nir_instr_as_intrinsic(instr)->def.bit_size != 64; return nir_instr_as_intrinsic(instr)->def.bit_size != 64;
} }
/**
* Filters the real_wavesize that was set based on API requirements, to an
* appopriate value given hardware limits and the NIR shader we get.
*
* The final wavesize in the SINGLE_OR_DOUBLE case will be determined later
* based on register allocation.
*/
static void
ir3_nir_set_threadsize(struct ir3_shader_variant *v, const nir_shader *s)
{
if (v->shader_options.real_wavesize != IR3_SINGLE_OR_DOUBLE)
return;
if (mesa_shader_stage_is_compute(v->type)) {
struct ir3_compiler *compiler = v->compiler;
const shader_info *info = &s->info;
unsigned threads_per_wg = info->workgroup_size[0] *
info->workgroup_size[1] *
info->workgroup_size[2];
/* If the workgroups fit in the base threadsize, then doubling would just
* leave us with an unused second half of each wave for no gain (the HW
* can't pack multiple workgroups into a wave, because the workgroups
* might make different barrier choices).
*/
if (!info->workgroup_size_variable) {
if (threads_per_wg <= compiler->threadsize_base)
v->shader_options.real_wavesize = IR3_SINGLE_ONLY;
}
/* For a5xx, if the workgroup size is greater than the maximum number
* of threads per core with 32 threads per wave (512) then we have to
* use the doubled threadsize because otherwise the workgroup wouldn't
* fit. For smaller workgroup sizes, we follow the blob and use the
* smaller threadsize.
*
* For a6xx, because threadsize_base is bumped to 64, we don't have to
* worry about the workgroup fitting.
*/
if (compiler->gen < 6 &&
(info->workgroup_size_variable ||
threads_per_wg > compiler->threadsize_base * compiler->max_waves)) {
v->shader_options.real_wavesize = IR3_DOUBLE_ONLY;
};
}
}
void void
ir3_nir_lower_variant(struct ir3_shader_variant *so, ir3_nir_lower_variant(struct ir3_shader_variant *so,
const struct ir3_shader_nir_options *options, const struct ir3_shader_nir_options *options,
@ -1126,6 +1173,8 @@ ir3_nir_lower_variant(struct ir3_shader_variant *so,
mesa_logi("----------------------"); mesa_logi("----------------------");
} }
ir3_nir_set_threadsize(so, s);
bool progress = false; bool progress = false;
progress |= OPT(s, nir_lower_io_to_scalar, nir_var_mem_ssbo, progress |= OPT(s, nir_lower_io_to_scalar, nir_var_mem_ssbo,