anv: Implement RT shader group handle capture/replay

Signed-off-by: Michael Cheng <michael.cheng@intel.com>
Reviewed-by: Lionel Landwerlin <lionel.g.landwerlin@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/33022>
This commit is contained in:
Michael Cheng 2026-01-23 10:24:33 +02:00 committed by Marge Bot
parent 51966cc452
commit 4f82dfc5f5
6 changed files with 183 additions and 33 deletions

View file

@@ -298,7 +298,8 @@ void genX(batch_emit_fast_color_dummy_blit)(struct anv_batch *batch,
(struct GENX(BINDLESS_SHADER_RECORD)) { \
.OffsetToLocalArguments = (local_arg_offset) / 8, \
.BindlessShaderDispatchMode = RT_SIMD16, \
.KernelStartPointer = shader->kernel.offset, \
.KernelStartPointer = shader->replay_kernel.alloc_size != 0 ? \
shader->replay_kernel.offset : shader->kernel.offset, \
.RegistersPerThread = ptl_register_blocks(prog_data->base.grf_used), \
}; \
})
@@ -313,7 +314,8 @@ void genX(batch_emit_fast_color_dummy_blit)(struct anv_batch *batch,
.OffsetToLocalArguments = (local_arg_offset) / 8, \
.BindlessShaderDispatchMode = \
prog_data->simd_size == 16 ? RT_SIMD16 : RT_SIMD8, \
.KernelStartPointer = shader->kernel.offset, \
.KernelStartPointer = shader->replay_kernel.alloc_size != 0 ? \
shader->replay_kernel.offset : shader->kernel.offset, \
}; \
})
#endif

View file

@@ -715,7 +715,7 @@ get_features(const struct anv_physical_device *pdevice,
/* VK_KHR_ray_tracing_pipeline */
.rayTracingPipeline = rt_enabled,
.rayTracingPipelineShaderGroupHandleCaptureReplay = false,
.rayTracingPipelineShaderGroupHandleCaptureReplay = true,
.rayTracingPipelineShaderGroupHandleCaptureReplayMixed = false,
.rayTracingPipelineTraceRaysIndirect = rt_enabled,
.rayTraversalPrimitiveCulling = rt_enabled,
@@ -1622,7 +1622,8 @@ get_properties(const struct anv_physical_device *pdevice,
/* MemRay::hitGroupSRBasePtr requires 16B alignment */
props->shaderGroupBaseAlignment = 16;
props->shaderGroupHandleAlignment = 16;
props->shaderGroupHandleCaptureReplaySize = 32;
props->shaderGroupHandleCaptureReplaySize =
sizeof(struct anv_shader_group_rt_replay);
props->maxRayDispatchInvocationCount = 1U << 30; /* required min limit */
props->maxRayHitAttributeSize = BRW_RT_SIZEOF_HIT_ATTRIB_DATA;
}

View file

@@ -1280,13 +1280,20 @@ struct anv_shader_alloc anv_shader_heap_alloc(struct anv_shader_heap *heap,
uint64_t size,
uint64_t align,
bool capture_replay,
uint64_t requested_addr);
uint64_t requested_offset);
void anv_shader_heap_free(struct anv_shader_heap *heap, struct anv_shader_alloc alloc);
void anv_shader_heap_upload(struct anv_shader_heap *heap,
struct anv_shader_alloc alloc,
const void *data, uint64_t size);
/* Opaque capture/replay data written for a ray-tracing shader group.
 *
 * One shader-heap offset per possible stage slot of a group; slots for
 * stages not present in the group are left at 0.  The size of this struct
 * is advertised as shaderGroupHandleCaptureReplaySize (see
 * get_properties), so its layout is ABI between capture and replay runs.
 */
struct anv_shader_group_rt_replay {
uint64_t general;      /* raygen / miss / callable stages share this slot */
uint64_t closest_hit;
uint64_t any_hit;
uint64_t intersection;
};
struct anv_shader {
struct vk_shader vk;
@@ -1315,6 +1322,10 @@ struct anv_shader {
*/
struct anv_embedded_sampler **embedded_samplers;
/* Mutex to protect the lazy replay allocation */
simple_mtx_t replay_mutex;
struct anv_shader_alloc replay_kernel;
struct anv_reloc_list relocs;
union {
@@ -1391,6 +1402,23 @@ struct anv_shader {
extern struct vk_device_shader_ops anv_device_shader_ops;
void anv_write_rt_shader_group(struct vk_device *vk_device,
VkRayTracingShaderGroupTypeKHR type,
const struct vk_shader **shaders,
uint32_t shader_count,
void *output);
void anv_write_rt_shader_group_replay_handle(struct vk_device *vk_device,
const struct vk_shader **shaders,
uint32_t shader_count,
void *output);
void anv_replay_rt_shader_group(struct vk_device *vk_device,
VkRayTracingShaderGroupTypeKHR type,
uint32_t shader_count,
struct vk_shader **vk_shaders,
const void *replay_data);
/* Physical device */
struct anv_queue_family {

View file

@@ -24,7 +24,10 @@ anv_shader_destroy(struct vk_device *vk_device,
anv_embedded_sampler_unref(device, shader->embedded_samplers[i]);
anv_shader_heap_free(&device->shader_heap, shader->kernel);
if (shader->replay_kernel.alloc_size != 0)
anv_shader_heap_free(&device->shader_heap, shader->replay_kernel);
anv_reloc_list_finish(&shader->relocs);
simple_mtx_destroy(&shader->replay_mutex);
vk_shader_free(vk_device, pAllocator, vk_shader);
}
@@ -612,6 +615,8 @@ anv_shader_create(struct anv_device *device,
stage, pAllocator))
return vk_error(device, VK_ERROR_OUT_OF_HOST_MEMORY);
simple_mtx_init(&shader->replay_mutex, mtx_plain);
VkResult result;
if (shader_data->bind_map.embedded_sampler_count > 0) {
shader->embedded_samplers = embedded_samplers;
@@ -721,3 +726,139 @@ anv_shader_create(struct anv_device *device,
vk_shader_free(&device->vk, pAllocator, &shader->vk);
return result;
}
/* vk_device_shader_ops::write_rt_shader_group entry point.
 *
 * Thin dispatcher: resolves the per-generation implementation through
 * anv_genX() and forwards all arguments unchanged.
 */
void
anv_write_rt_shader_group(struct vk_device *vk_device,
                          VkRayTracingShaderGroupTypeKHR type,
                          const struct vk_shader **shaders,
                          uint32_t shader_count,
                          void *output)
{
   struct anv_device *const device =
      container_of(vk_device, struct anv_device, vk);

   anv_genX(device->info, write_rt_shader_group)(device, type, shaders,
                                                 shader_count, output);
}
/* vk_device_shader_ops::write_rt_shader_group_replay_handle entry point.
 *
 * Fills `output` (a struct anv_shader_group_rt_replay) with the shader-heap
 * offset of each member shader's replay kernel, keyed by stage.  Stages not
 * present in the group keep the zero offset from the memset below.
 *
 * NOTE(review): this reads shader->replay_kernel.offset, which is only
 * meaningful once the replay kernel has been allocated (lazily, in
 * anv_replay_rt_shader_group) -- presumably the pipeline was created with
 * the capture/replay flag so that allocation already happened; confirm
 * against the common-runtime caller.
 */
void
anv_write_rt_shader_group_replay_handle(struct vk_device *vk_device,
const struct vk_shader **vk_shaders,
uint32_t shader_count,
void *output)
{
/* A group references at most 3 shaders (e.g. closest-hit/any-hit/intersection). */
assert(shader_count <= 3);
struct anv_shader_group_rt_replay *replay_data = output;
/* Zero all slots so absent stages report a 0 offset. */
memset(replay_data, 0, sizeof(*replay_data));
for (uint32_t i = 0; i < shader_count; i++) {
/* Unused slots in the group are NULL. */
if (!vk_shaders[i])
continue;
const struct anv_shader *shader =
container_of(vk_shaders[i], struct anv_shader, vk);
switch (shader->vk.stage) {
case MESA_SHADER_RAYGEN:
case MESA_SHADER_CALLABLE:
case MESA_SHADER_MISS:
/* All "general" group stages share one slot; a group can only
 * contain one of them.
 */
replay_data->general = shader->replay_kernel.offset;
break;
case MESA_SHADER_ANY_HIT:
replay_data->any_hit = shader->replay_kernel.offset;
break;
case MESA_SHADER_CLOSEST_HIT:
replay_data->closest_hit = shader->replay_kernel.offset;
break;
case MESA_SHADER_INTERSECTION:
replay_data->intersection = shader->replay_kernel.offset;
break;
default:
UNREACHABLE("invalid stage");
}
}
}
/* vk_device_shader_ops::replay_rt_shader_group entry point.
 *
 * For each shader of a group, lazily create a second copy of the kernel
 * ("replay kernel") in the capture/replay range of the shader heap.  When
 * `replay_data` (a struct anv_shader_group_rt_replay) is provided, the copy
 * is placed at the exact heap offset recorded at capture time so group
 * handles match across runs; with replay_data == NULL the heap picks any
 * capture/replay address (offset stays 0, see anv_shader_heap_alloc).
 */
void
anv_replay_rt_shader_group(struct vk_device *vk_device,
VkRayTracingShaderGroupTypeKHR type,
uint32_t shader_count,
struct vk_shader **vk_shaders,
const void *replay_data)
{
struct anv_device *device = container_of(vk_device, struct anv_device, vk);
const struct anv_shader_group_rt_replay *data = replay_data;
for (uint32_t i = 0; i < shader_count; i++) {
struct anv_shader *shader =
container_of(vk_shaders[i], struct anv_shader, vk);
/* 0 means "no requested offset": the heap chooses an address. */
uint64_t offset = 0;
if (data != NULL) {
/* Pick the captured offset matching this shader's stage slot. */
switch (shader->vk.stage) {
case MESA_SHADER_RAYGEN:
case MESA_SHADER_CALLABLE:
case MESA_SHADER_MISS:
offset = data->general;
break;
case MESA_SHADER_ANY_HIT:
offset = data->any_hit;
break;
case MESA_SHADER_CLOSEST_HIT:
offset = data->closest_hit;
break;
case MESA_SHADER_INTERSECTION:
offset = data->intersection;
break;
default:
UNREACHABLE("invalid stage");
}
}
switch (type) {
case VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR:
case VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR:
break;
case VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR:
/* Anyhit is merged into intersection */
if (shader->vk.stage == MESA_SHADER_ANY_HIT)
continue;
break;
default:
UNREACHABLE("invalid group");
}
/* The replay kernel is created lazily, at most once per shader; the
 * mutex serializes concurrent pipeline creations sharing this shader.
 */
simple_mtx_lock(&shader->replay_mutex);
if (shader->replay_kernel.alloc_size == 0) {
shader->replay_kernel = anv_shader_heap_alloc(
&device->shader_heap,
shader->prog_data->program_size,
64, true, offset);
assert(shader->replay_kernel.alloc_size != 0);
/* TODO: make a copy of the code to leave it untouched */
/* NOTE(review): `result` is only consumed by the assert below, so it
 * is set-but-unused when asserts are compiled out; allocation/reloc
 * failure is not otherwise reported to the caller.
 */
VkResult result =
anv_shader_reloc(device, shader->code, shader, &vk_device->alloc);
assert(result == VK_SUCCESS);
anv_shader_heap_upload(&device->shader_heap,
shader->replay_kernel, shader->code,
shader->prog_data->program_size);
}
simple_mtx_unlock(&shader->replay_mutex);
}
}

View file

@@ -2114,30 +2114,6 @@ end:
return result;
}
/* Pre-commit wrapper, deleted by this change: the same dispatcher now lives
 * as a non-static function in anv_shader.c (see anv_write_rt_shader_group
 * above) so it can sit next to the capture/replay entry points.
 */
static void
anv_write_rt_shader_group(struct vk_device *vk_device,
VkRayTracingShaderGroupTypeKHR type,
const struct vk_shader **shaders,
uint32_t shader_count,
void *output)
{
struct anv_device *device =
container_of(vk_device, struct anv_device, vk);
anv_genX(device->info, write_rt_shader_group)(device, type,
shaders, shader_count,
output);
}
/* Old stub, deleted by this change: capture/replay handles were
 * unimplemented before this commit (the feature bit was false, so this
 * could never be reached).
 */
static void
anv_write_rt_shader_group_replay_handle(struct vk_device *device,
const struct vk_shader **shaders,
uint32_t shader_count,
void *output)
{
UNREACHABLE("Unimplemented");
}
struct vk_device_shader_ops anv_device_shader_ops = {
.get_nir_options = anv_shader_get_nir_options,
.get_spirv_options = anv_shader_get_spirv_options,
@@ -2146,6 +2122,7 @@ struct vk_device_shader_ops anv_device_shader_ops = {
.hash_state = anv_shader_hash_state,
.compile = anv_shader_compile,
.deserialize = anv_shader_deserialize,
.replay_rt_shader_group = anv_replay_rt_shader_group,
.write_rt_shader_group = anv_write_rt_shader_group,
.write_rt_shader_group_replay_handle = anv_write_rt_shader_group_replay_handle,
.cmd_bind_shaders = anv_cmd_buffer_bind_shaders,

View file

@@ -105,7 +105,7 @@ anv_shader_heap_alloc(struct anv_shader_heap *heap,
uint64_t size,
uint64_t align,
bool capture_replay,
uint64_t requested_addr)
uint64_t requested_offset)
{
assert(align <= heap->base_chunk_size);
assert(size <= heap->base_chunk_size);
@@ -117,10 +117,11 @@ anv_shader_heap_alloc(struct anv_shader_heap *heap,
heap->vma.nospan_shift++;
uint64_t addr = 0;
if (requested_addr) {
if (requested_offset) {
if (util_vma_heap_alloc_addr(&heap->vma,
heap->va_range.addr + requested_addr, size)) {
addr = requested_addr;
heap->va_range.addr + requested_offset,
size)) {
addr = heap->va_range.addr + requested_offset;
}
} else {
if (capture_replay) {