diff --git a/.pick_status.json b/.pick_status.json index 7ab3f3fc864..b51ed0756c6 100644 --- a/.pick_status.json +++ b/.pick_status.json @@ -1734,7 +1734,7 @@ "description": "lavapipe: fix mesh+task binding with shader objects", "nominated": true, "nomination_type": 0, - "resolution": 0, + "resolution": 1, "main_sha": null, "because_sha": null, "notes": null diff --git a/src/gallium/frontends/lavapipe/lvp_execute.c b/src/gallium/frontends/lavapipe/lvp_execute.c index 9a9e187f122..e02f788c40b 100644 --- a/src/gallium/frontends/lavapipe/lvp_execute.c +++ b/src/gallium/frontends/lavapipe/lvp_execute.c @@ -730,19 +730,11 @@ handle_graphics_stages(struct rendering_state *state, VkShaderStageFlagBits shad break; case VK_SHADER_STAGE_TASK_BIT_EXT: state->inlines_dirty[MESA_SHADER_TASK] = state->shaders[MESA_SHADER_TASK]->inlines.can_inline; - state->dispatch_info.block[0] = state->shaders[MESA_SHADER_TASK]->pipeline_nir->nir->info.workgroup_size[0]; - state->dispatch_info.block[1] = state->shaders[MESA_SHADER_TASK]->pipeline_nir->nir->info.workgroup_size[1]; - state->dispatch_info.block[2] = state->shaders[MESA_SHADER_TASK]->pipeline_nir->nir->info.workgroup_size[2]; if (!state->shaders[MESA_SHADER_TASK]->inlines.can_inline) state->pctx->bind_ts_state(state->pctx, state->shaders[MESA_SHADER_TASK]->shader_cso); break; case VK_SHADER_STAGE_MESH_BIT_EXT: state->inlines_dirty[MESA_SHADER_MESH] = state->shaders[MESA_SHADER_MESH]->inlines.can_inline; - if (!(shader_stages & VK_SHADER_STAGE_TASK_BIT_EXT)) { - state->dispatch_info.block[0] = state->shaders[MESA_SHADER_MESH]->pipeline_nir->nir->info.workgroup_size[0]; - state->dispatch_info.block[1] = state->shaders[MESA_SHADER_MESH]->pipeline_nir->nir->info.workgroup_size[1]; - state->dispatch_info.block[2] = state->shaders[MESA_SHADER_MESH]->pipeline_nir->nir->info.workgroup_size[2]; - } if (!state->shaders[MESA_SHADER_MESH]->inlines.can_inline) state->pctx->bind_ms_state(state->pctx, state->shaders[MESA_SHADER_MESH]->shader_cso); break; @@ -3863,9 +3855,24 @@ handle_shaders(struct vk_cmd_queue_entry *cmd, struct rendering_state *state) } } +static void +update_mesh_state(struct rendering_state *state) +{ + if (state->shaders[MESA_SHADER_TASK]) { + state->dispatch_info.block[0] = state->shaders[MESA_SHADER_TASK]->pipeline_nir->nir->info.workgroup_size[0]; + state->dispatch_info.block[1] = state->shaders[MESA_SHADER_TASK]->pipeline_nir->nir->info.workgroup_size[1]; + state->dispatch_info.block[2] = state->shaders[MESA_SHADER_TASK]->pipeline_nir->nir->info.workgroup_size[2]; + } else { + state->dispatch_info.block[0] = state->shaders[MESA_SHADER_MESH]->pipeline_nir->nir->info.workgroup_size[0]; + state->dispatch_info.block[1] = state->shaders[MESA_SHADER_MESH]->pipeline_nir->nir->info.workgroup_size[1]; + state->dispatch_info.block[2] = state->shaders[MESA_SHADER_MESH]->pipeline_nir->nir->info.workgroup_size[2]; + } +} + static void handle_draw_mesh_tasks(struct vk_cmd_queue_entry *cmd, struct rendering_state *state) { + update_mesh_state(state); state->dispatch_info.grid[0] = cmd->u.draw_mesh_tasks_ext.group_count_x; state->dispatch_info.grid[1] = cmd->u.draw_mesh_tasks_ext.group_count_y; state->dispatch_info.grid[2] = cmd->u.draw_mesh_tasks_ext.group_count_z; @@ -3880,6 +3887,7 @@ static void handle_draw_mesh_tasks(struct vk_cmd_queue_entry *cmd, static void handle_draw_mesh_tasks_indirect(struct vk_cmd_queue_entry *cmd, struct rendering_state *state) { + update_mesh_state(state); state->dispatch_info.indirect = lvp_buffer_from_handle(cmd->u.draw_mesh_tasks_indirect_ext.buffer)->bo; state->dispatch_info.indirect_offset = cmd->u.draw_mesh_tasks_indirect_ext.offset; state->dispatch_info.indirect_stride = cmd->u.draw_mesh_tasks_indirect_ext.stride; @@ -3890,6 +3898,7 @@ static void handle_draw_mesh_tasks_indirect(struct vk_cmd_queue_entry *cmd, static void handle_draw_mesh_tasks_indirect_count(struct vk_cmd_queue_entry *cmd, struct rendering_state *state) { + update_mesh_state(state); state->dispatch_info.indirect = lvp_buffer_from_handle(cmd->u.draw_mesh_tasks_indirect_count_ext.buffer)->bo; state->dispatch_info.indirect_offset = cmd->u.draw_mesh_tasks_indirect_count_ext.offset; state->dispatch_info.indirect_stride = cmd->u.draw_mesh_tasks_indirect_count_ext.stride;