anv: Make use of anv_shader_get_scratch_surf() in genX_cmd_compute.c

genX_cmd_compute.c had 2 places with code very similar to
anv_shader_get_scratch_surf(), but we could not make use of this function without
changing its parameters.

Now it takes the shader stage and the total_scratch instead of an anv_shader because
cmd_buffer_trace_rays() doesn't have a shader.

Reviewed-by: Lionel Landwerlin <lionel.g.landwerlin@intel.com>
Signed-off-by: José Roberto de Souza <jose.souza@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/40832>
This commit is contained in:
José Roberto de Souza 2026-04-02 13:49:10 -07:00 committed by Marge Bot
parent fd420e80e2
commit a69e02d97c
4 changed files with 60 additions and 46 deletions

View file

@ -1486,20 +1486,19 @@ anv_scratch_pool_get_surf(struct anv_device *device,
uint32_t
anv_shader_get_scratch_surf(struct anv_batch *batch,
struct anv_device *device,
struct anv_shader *shader,
mesa_shader_stage stage,
uint32_t total_scratch,
bool protected)
{
if (shader->prog_data->total_scratch == 0)
if (total_scratch == 0)
return 0;
struct anv_scratch_pool *pool = protected ?
&device->protected_scratch_pool : &device->scratch_pool;
struct anv_bo *bo =
anv_scratch_pool_alloc(device, pool, shader->vk.stage,
shader->prog_data->total_scratch);
anv_scratch_pool_alloc(device, pool, stage, total_scratch);
anv_reloc_list_add_bo(batch->relocs, bo);
uint32_t ret = anv_scratch_pool_get_surf(
device, pool, shader->prog_data->total_scratch);
uint32_t ret = anv_scratch_pool_get_surf(device, pool, total_scratch);
return ret >> ANV_SCRATCH_SPACE_SHIFT;
}

View file

@ -1425,7 +1425,8 @@ struct anv_shader {
uint32_t anv_shader_get_scratch_surf(struct anv_batch *batch,
struct anv_device *device,
struct anv_shader *shader,
mesa_shader_stage stage,
uint32_t total_scratch,
bool protected);
extern struct vk_device_shader_ops anv_device_shader_ops;

View file

@ -55,19 +55,12 @@ genX(cmd_buffer_ensure_cfe_state)(struct anv_cmd_buffer *cmd_buffer,
anv_batch_emit(&cmd_buffer->batch, GENX(CFE_STATE), cfe) {
cfe.MaximumNumberofThreads = devinfo->max_cs_threads * devinfo->subslice_total;
uint32_t scratch_surf;
struct anv_scratch_pool *scratch_pool =
(cmd_buffer->vk.pool->flags & VK_COMMAND_POOL_CREATE_PROTECTED_BIT) ?
&cmd_buffer->device->protected_scratch_pool :
&cmd_buffer->device->scratch_pool;
struct anv_bo *scratch_bo =
anv_scratch_pool_alloc(cmd_buffer->device, scratch_pool,
MESA_SHADER_COMPUTE,
total_scratch);
anv_reloc_list_add_bo(cmd_buffer->batch.relocs, scratch_bo);
scratch_surf = anv_scratch_pool_get_surf(cmd_buffer->device, scratch_pool,
total_scratch);
cfe.ScratchSpaceBuffer = scratch_surf >> ANV_SCRATCH_SPACE_SHIFT;
const bool protected = cmd_buffer->vk.pool->flags & VK_COMMAND_POOL_CREATE_PROTECTED_BIT;
cfe.ScratchSpaceBuffer = anv_shader_get_scratch_surf(&cmd_buffer->batch,
cmd_buffer->device,
MESA_SHADER_COMPUTE,
total_scratch,
protected);
#if GFX_VER >= 20
switch (cmd_buffer->device->physical->instance->stack_ids) {
case 256: cfe.StackIDControl = StackIDs256; break;
@ -1303,18 +1296,11 @@ cmd_buffer_trace_rays(struct anv_cmd_buffer *cmd_buffer,
btd.PerDSSMemoryBackedBufferSize = 6;
btd.MemoryBackedBufferBasePointer = (struct anv_address) { .bo = device->btd_fifo_bo };
if (rt->scratch_size > 0) {
struct anv_bo *scratch_bo =
anv_scratch_pool_alloc(device,
&device->scratch_pool,
MESA_SHADER_COMPUTE,
rt->scratch_size);
anv_reloc_list_add_bo(cmd_buffer->batch.relocs,
scratch_bo);
uint32_t scratch_surf =
anv_scratch_pool_get_surf(cmd_buffer->device,
&device->scratch_pool,
rt->scratch_size);
btd.ScratchSpaceBuffer = scratch_surf >> ANV_SCRATCH_SPACE_SHIFT;
btd.ScratchSpaceBuffer = anv_shader_get_scratch_surf(&cmd_buffer->batch,
cmd_buffer->device,
MESA_SHADER_COMPUTE,
rt->scratch_size,
false);;
}
#if INTEL_NEEDS_WA_14017794102 || INTEL_NEEDS_WA_14023061436
btd.BTDMidthreadpreemption = false;

View file

@ -581,7 +581,9 @@ emit_vs_shader(struct anv_batch *batch,
anv_shader_emit_merge(batch, shader, vs.vs, vs_dwords, GENX(3DSTATE_VS), vs) {
#if GFX_VERx10 >= 125
vs.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader, false);
vs.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader->vk.stage,
shader->prog_data->total_scratch,
false);
#else
vs.PerThreadScratchSpace = get_scratch_space(shader);
vs.ScratchSpaceBasePointer = get_scratch_address(device, shader);
@ -591,7 +593,9 @@ emit_vs_shader(struct anv_batch *batch,
anv_shader_emit_merge(batch, shader, vs.vs_protected,
vs_dwords, GENX(3DSTATE_VS), vs) {
#if GFX_VERx10 >= 125
vs.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader, true);
vs.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader->vk.stage,
shader->prog_data->total_scratch,
true);
#else
vs.PerThreadScratchSpace = get_scratch_space(shader);
vs.ScratchSpaceBasePointer = get_scratch_address(device, shader);
@ -659,7 +663,9 @@ emit_hs_shader(struct anv_batch *batch,
anv_shader_emit_merge(batch, shader, hs.hs, hs_dwords, GENX(3DSTATE_HS), hs) {
#if GFX_VERx10 >= 125
hs.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader, false);
hs.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader->vk.stage,
shader->prog_data->total_scratch,
false);
#else
hs.PerThreadScratchSpace = get_scratch_space(shader);
hs.ScratchSpaceBasePointer = get_scratch_address(device, shader);
@ -669,7 +675,9 @@ emit_hs_shader(struct anv_batch *batch,
anv_shader_emit_merge(batch, shader, hs.hs_protected,
hs_dwords, GENX(3DSTATE_HS), hs) {
#if GFX_VERx10 >= 125
hs.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader, false);
hs.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader->vk.stage,
shader->prog_data->total_scratch,
false);
#else
hs.PerThreadScratchSpace = get_scratch_space(shader);
hs.ScratchSpaceBasePointer = get_scratch_address(device, shader);
@ -753,7 +761,9 @@ emit_ds_shader(struct anv_batch *batch,
anv_shader_emit_merge(batch, shader, ds.ds, ds_dwords, GENX(3DSTATE_DS), ds) {
#if GFX_VERx10 >= 125
ds.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader, false);
ds.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader->vk.stage,
shader->prog_data->total_scratch,
false);
#else
ds.PerThreadScratchSpace = get_scratch_space(shader);
ds.ScratchSpaceBasePointer = get_scratch_address(device, shader);
@ -763,7 +773,9 @@ emit_ds_shader(struct anv_batch *batch,
anv_shader_emit_merge(batch, shader, ds.ds_protected,
ds_dwords, GENX(3DSTATE_DS), ds) {
#if GFX_VERx10 >= 125
ds.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader, true);
ds.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader->vk.stage,
shader->prog_data->total_scratch,
true);
#else
ds.PerThreadScratchSpace = get_scratch_space(shader);
ds.ScratchSpaceBasePointer = get_scratch_address(device, shader);
@ -827,7 +839,9 @@ emit_gs_shader(struct anv_batch *batch,
anv_shader_emit_merge(batch, shader, gs.gs, gs_dwords, GENX(3DSTATE_GS), gs) {
#if GFX_VERx10 >= 125
gs.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader, false);
gs.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader->vk.stage,
shader->prog_data->total_scratch,
false);
#else
gs.PerThreadScratchSpace = get_scratch_space(shader);
gs.ScratchSpaceBasePointer = get_scratch_address(device, shader);
@ -837,7 +851,9 @@ emit_gs_shader(struct anv_batch *batch,
anv_shader_emit_merge(batch, shader, gs.gs_protected,
gs_dwords, GENX(3DSTATE_GS), gs) {
#if GFX_VERx10 >= 125
gs.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader, true);
gs.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader->vk.stage,
shader->prog_data->total_scratch,
true);
#else
gs.PerThreadScratchSpace = get_scratch_space(shader);
gs.ScratchSpaceBasePointer = get_scratch_address(device, shader);
@ -867,12 +883,16 @@ emit_task_shader(struct anv_batch *batch,
anv_shader_emit_merge(batch, shader, ts.control,
task_control_dwords, GENX(3DSTATE_TASK_CONTROL), tc) {
tc.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader, false);
tc.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader->vk.stage,
shader->prog_data->total_scratch,
false);
}
if (device_needs_protected(device)) {
anv_shader_emit_merge(batch, shader, ts.control_protected,
task_control_dwords, GENX(3DSTATE_TASK_CONTROL), tc) {
tc.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader, true);
tc.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader->vk.stage,
shader->prog_data->total_scratch,
true);
}
}
@ -943,12 +963,16 @@ emit_mesh_shader(struct anv_batch *batch,
anv_shader_emit_merge(batch, shader, ms.control,
mesh_control_dwords, GENX(3DSTATE_MESH_CONTROL), mc) {
mc.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader, false);
mc.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader->vk.stage,
shader->prog_data->total_scratch,
false);
}
if (device_needs_protected(device)) {
anv_shader_emit_merge(batch, shader, ms.control_protected,
mesh_control_dwords, GENX(3DSTATE_MESH_CONTROL), mc) {
mc.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader, true);
mc.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader->vk.stage,
shader->prog_data->total_scratch,
true);
}
}
@ -1067,7 +1091,9 @@ emit_ps_shader(struct anv_batch *batch,
anv_shader_emit_merge(batch, shader, ps.ps, ps_dwords, GENX(3DSTATE_PS), ps) {
#if GFX_VERx10 >= 125
ps.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader, false);
ps.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader->vk.stage,
shader->prog_data->total_scratch,
false);
#else
ps.PerThreadScratchSpace = get_scratch_space(shader);
ps.ScratchSpaceBasePointer = get_scratch_address(device, shader);
@ -1077,7 +1103,9 @@ emit_ps_shader(struct anv_batch *batch,
anv_shader_emit_merge(batch, shader, ps.ps_protected,
ps_dwords, GENX(3DSTATE_PS), ps) {
#if GFX_VERx10 >= 125
ps.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader, true);
ps.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader->vk.stage,
shader->prog_data->total_scratch,
true);
#else
ps.PerThreadScratchSpace = get_scratch_space(shader);
ps.ScratchSpaceBasePointer = get_scratch_address(device, shader);