diff --git a/src/intel/vulkan/anv_allocator.c b/src/intel/vulkan/anv_allocator.c index 5ece4805f11..6aeb1fc835a 100644 --- a/src/intel/vulkan/anv_allocator.c +++ b/src/intel/vulkan/anv_allocator.c @@ -1486,20 +1486,19 @@ anv_scratch_pool_get_surf(struct anv_device *device, uint32_t anv_shader_get_scratch_surf(struct anv_batch *batch, struct anv_device *device, - struct anv_shader *shader, + mesa_shader_stage stage, + uint32_t total_scratch, bool protected) { - if (shader->prog_data->total_scratch == 0) + if (total_scratch == 0) return 0; struct anv_scratch_pool *pool = protected ? &device->protected_scratch_pool : &device->scratch_pool; struct anv_bo *bo = - anv_scratch_pool_alloc(device, pool, shader->vk.stage, - shader->prog_data->total_scratch); + anv_scratch_pool_alloc(device, pool, stage, total_scratch); anv_reloc_list_add_bo(batch->relocs, bo); - uint32_t ret = anv_scratch_pool_get_surf( - device, pool, shader->prog_data->total_scratch); + uint32_t ret = anv_scratch_pool_get_surf(device, pool, total_scratch); return ret >> ANV_SCRATCH_SPACE_SHIFT; } diff --git a/src/intel/vulkan/anv_private.h b/src/intel/vulkan/anv_private.h index 9d4b202c598..747132be859 100644 --- a/src/intel/vulkan/anv_private.h +++ b/src/intel/vulkan/anv_private.h @@ -1425,7 +1425,8 @@ struct anv_shader { uint32_t anv_shader_get_scratch_surf(struct anv_batch *batch, struct anv_device *device, - struct anv_shader *shader, + mesa_shader_stage stage, + uint32_t total_scratch, bool protected); extern struct vk_device_shader_ops anv_device_shader_ops; diff --git a/src/intel/vulkan/genX_cmd_compute.c b/src/intel/vulkan/genX_cmd_compute.c index 9808cd7b260..5e94688cfdf 100644 --- a/src/intel/vulkan/genX_cmd_compute.c +++ b/src/intel/vulkan/genX_cmd_compute.c @@ -55,19 +55,12 @@ genX(cmd_buffer_ensure_cfe_state)(struct anv_cmd_buffer *cmd_buffer, anv_batch_emit(&cmd_buffer->batch, GENX(CFE_STATE), cfe) { cfe.MaximumNumberofThreads = devinfo->max_cs_threads * devinfo->subslice_total; - uint32_t scratch_surf; - struct anv_scratch_pool *scratch_pool = - (cmd_buffer->vk.pool->flags & VK_COMMAND_POOL_CREATE_PROTECTED_BIT) ? - &cmd_buffer->device->protected_scratch_pool : - &cmd_buffer->device->scratch_pool; - struct anv_bo *scratch_bo = - anv_scratch_pool_alloc(cmd_buffer->device, scratch_pool, - MESA_SHADER_COMPUTE, - total_scratch); - anv_reloc_list_add_bo(cmd_buffer->batch.relocs, scratch_bo); - scratch_surf = anv_scratch_pool_get_surf(cmd_buffer->device, scratch_pool, - total_scratch); - cfe.ScratchSpaceBuffer = scratch_surf >> ANV_SCRATCH_SPACE_SHIFT; + const bool protected = cmd_buffer->vk.pool->flags & VK_COMMAND_POOL_CREATE_PROTECTED_BIT; + cfe.ScratchSpaceBuffer = anv_shader_get_scratch_surf(&cmd_buffer->batch, + cmd_buffer->device, + MESA_SHADER_COMPUTE, + total_scratch, + protected); #if GFX_VER >= 20 switch (cmd_buffer->device->physical->instance->stack_ids) { case 256: cfe.StackIDControl = StackIDs256; break; @@ -1303,18 +1296,11 @@ cmd_buffer_trace_rays(struct anv_cmd_buffer *cmd_buffer, btd.PerDSSMemoryBackedBufferSize = 6; btd.MemoryBackedBufferBasePointer = (struct anv_address) { .bo = device->btd_fifo_bo }; if (rt->scratch_size > 0) { - struct anv_bo *scratch_bo = - anv_scratch_pool_alloc(device, - &device->scratch_pool, - MESA_SHADER_COMPUTE, - rt->scratch_size); - anv_reloc_list_add_bo(cmd_buffer->batch.relocs, - scratch_bo); - uint32_t scratch_surf = - anv_scratch_pool_get_surf(cmd_buffer->device, - &device->scratch_pool, - rt->scratch_size); - btd.ScratchSpaceBuffer = scratch_surf >> ANV_SCRATCH_SPACE_SHIFT; + btd.ScratchSpaceBuffer = anv_shader_get_scratch_surf(&cmd_buffer->batch, + cmd_buffer->device, + MESA_SHADER_COMPUTE, + rt->scratch_size, + false);; } #if INTEL_NEEDS_WA_14017794102 || INTEL_NEEDS_WA_14023061436 btd.BTDMidthreadpreemption = false; diff --git a/src/intel/vulkan/genX_shader.c b/src/intel/vulkan/genX_shader.c index e341af1b3d3..5f16826709b 100644 --- a/src/intel/vulkan/genX_shader.c +++ b/src/intel/vulkan/genX_shader.c @@ -581,7 +581,9 @@ emit_vs_shader(struct anv_batch *batch, anv_shader_emit_merge(batch, shader, vs.vs, vs_dwords, GENX(3DSTATE_VS), vs) { #if GFX_VERx10 >= 125 - vs.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader, false); + vs.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader->vk.stage, + shader->prog_data->total_scratch, + false); #else vs.PerThreadScratchSpace = get_scratch_space(shader); vs.ScratchSpaceBasePointer = get_scratch_address(device, shader); @@ -591,7 +593,9 @@ emit_vs_shader(struct anv_batch *batch, anv_shader_emit_merge(batch, shader, vs.vs_protected, vs_dwords, GENX(3DSTATE_VS), vs) { #if GFX_VERx10 >= 125 - vs.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader, true); + vs.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader->vk.stage, + shader->prog_data->total_scratch, + true); #else vs.PerThreadScratchSpace = get_scratch_space(shader); vs.ScratchSpaceBasePointer = get_scratch_address(device, shader); @@ -659,7 +663,9 @@ emit_hs_shader(struct anv_batch *batch, anv_shader_emit_merge(batch, shader, hs.hs, hs_dwords, GENX(3DSTATE_HS), hs) { #if GFX_VERx10 >= 125 - hs.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader, false); + hs.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader->vk.stage, + shader->prog_data->total_scratch, + false); #else hs.PerThreadScratchSpace = get_scratch_space(shader); hs.ScratchSpaceBasePointer = get_scratch_address(device, shader); @@ -669,7 +675,9 @@ emit_hs_shader(struct anv_batch *batch, anv_shader_emit_merge(batch, shader, hs.hs_protected, hs_dwords, GENX(3DSTATE_HS), hs) { #if GFX_VERx10 >= 125 - hs.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader, false); + hs.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader->vk.stage, + shader->prog_data->total_scratch, + false); #else hs.PerThreadScratchSpace = get_scratch_space(shader); hs.ScratchSpaceBasePointer = get_scratch_address(device, shader); @@ -753,7 +761,9 @@ emit_ds_shader(struct anv_batch *batch, anv_shader_emit_merge(batch, shader, ds.ds, ds_dwords, GENX(3DSTATE_DS), ds) { #if GFX_VERx10 >= 125 - ds.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader, false); + ds.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader->vk.stage, + shader->prog_data->total_scratch, + false); #else ds.PerThreadScratchSpace = get_scratch_space(shader); ds.ScratchSpaceBasePointer = get_scratch_address(device, shader); @@ -763,7 +773,9 @@ emit_ds_shader(struct anv_batch *batch, anv_shader_emit_merge(batch, shader, ds.ds_protected, ds_dwords, GENX(3DSTATE_DS), ds) { #if GFX_VERx10 >= 125 - ds.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader, true); + ds.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader->vk.stage, + shader->prog_data->total_scratch, + true); #else ds.PerThreadScratchSpace = get_scratch_space(shader); ds.ScratchSpaceBasePointer = get_scratch_address(device, shader); @@ -827,7 +839,9 @@ emit_gs_shader(struct anv_batch *batch, anv_shader_emit_merge(batch, shader, gs.gs, gs_dwords, GENX(3DSTATE_GS), gs) { #if GFX_VERx10 >= 125 - gs.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader, false); + gs.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader->vk.stage, + shader->prog_data->total_scratch, + false); #else gs.PerThreadScratchSpace = get_scratch_space(shader); gs.ScratchSpaceBasePointer = get_scratch_address(device, shader); @@ -837,7 +851,9 @@ emit_gs_shader(struct anv_batch *batch, anv_shader_emit_merge(batch, shader, gs.gs_protected, gs_dwords, GENX(3DSTATE_GS), gs) { #if GFX_VERx10 >= 125 - gs.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader, true); + gs.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader->vk.stage, + shader->prog_data->total_scratch, + true); #else gs.PerThreadScratchSpace = get_scratch_space(shader); gs.ScratchSpaceBasePointer = get_scratch_address(device, shader); @@ -867,12 +883,16 @@ emit_task_shader(struct anv_batch *batch, anv_shader_emit_merge(batch, shader, ts.control, task_control_dwords, GENX(3DSTATE_TASK_CONTROL), tc) { - tc.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader, false); + tc.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader->vk.stage, + shader->prog_data->total_scratch, + false); } if (device_needs_protected(device)) { anv_shader_emit_merge(batch, shader, ts.control_protected, task_control_dwords, GENX(3DSTATE_TASK_CONTROL), tc) { - tc.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader, true); + tc.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader->vk.stage, + shader->prog_data->total_scratch, + true); } } @@ -943,12 +963,16 @@ emit_mesh_shader(struct anv_batch *batch, anv_shader_emit_merge(batch, shader, ms.control, mesh_control_dwords, GENX(3DSTATE_MESH_CONTROL), mc) { - mc.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader, false); + mc.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader->vk.stage, + shader->prog_data->total_scratch, + false); } if (device_needs_protected(device)) { anv_shader_emit_merge(batch, shader, ms.control_protected, mesh_control_dwords, GENX(3DSTATE_MESH_CONTROL), mc) { - mc.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader, true); + mc.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader->vk.stage, + shader->prog_data->total_scratch, + true); } } @@ -1067,7 +1091,9 @@ emit_ps_shader(struct anv_batch *batch, anv_shader_emit_merge(batch, shader, ps.ps, ps_dwords, GENX(3DSTATE_PS), ps) { #if GFX_VERx10 >= 125 - ps.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader, false); + ps.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader->vk.stage, + shader->prog_data->total_scratch, + false); #else ps.PerThreadScratchSpace = get_scratch_space(shader); ps.ScratchSpaceBasePointer = get_scratch_address(device, shader); @@ -1077,7 +1103,9 @@ emit_ps_shader(struct anv_batch *batch, anv_shader_emit_merge(batch, shader, ps.ps_protected, ps_dwords, GENX(3DSTATE_PS), ps) { #if GFX_VERx10 >= 125 - ps.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader, true); + ps.ScratchSpaceBuffer = anv_shader_get_scratch_surf(batch, device, shader->vk.stage, + shader->prog_data->total_scratch, + true); #else ps.PerThreadScratchSpace = get_scratch_space(shader); ps.ScratchSpaceBasePointer = get_scratch_address(device, shader);