lavapipe: add support for task/mesh shader stages in various places

this bumps the LVP_SHADER_STAGES to allow task/mesh shaders to be used,
and adds them to various state binding and execution places.

Reviewed-by: Mike Blumenkrantz <michael.blumenkrantz@gmail.com>
Reviewed-by: Roland Scheidegger <sroland@vmware.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/23066>
This commit is contained in:
Dave Airlie 2023-05-17 12:10:45 +10:00
parent 092b1daaf6
commit 5c6c226f5a
4 changed files with 85 additions and 7 deletions

View file

@@ -192,6 +192,8 @@ VKAPI_ATTR VkResult VKAPI_CALL lvp_CreateDescriptorSetLayout(
VK_SHADER_STAGE_GEOMETRY_BIT,
VK_SHADER_STAGE_FRAGMENT_BIT,
VK_SHADER_STAGE_COMPUTE_BIT,
VK_SHADER_STAGE_TASK_BIT_EXT,
VK_SHADER_STAGE_MESH_BIT_EXT,
};
lvp_forall_stage(i) {
uint16_t const_buffer_count = 0;
@@ -256,6 +258,8 @@ lvp_pipeline_layout_create(struct lvp_device *device,
VK_SHADER_STAGE_GEOMETRY_BIT,
VK_SHADER_STAGE_FRAGMENT_BIT,
VK_SHADER_STAGE_COMPUTE_BIT,
VK_SHADER_STAGE_TASK_BIT_EXT,
VK_SHADER_STAGE_MESH_BIT_EXT,
};
lvp_forall_stage(i) {

View file

@@ -380,6 +380,12 @@ update_inline_shader_state(struct rendering_state *state, enum pipe_shader_type
case MESA_SHADER_GEOMETRY:
state->pctx->bind_gs_state(state->pctx, shader_state);
break;
case MESA_SHADER_TASK:
state->pctx->bind_ts_state(state->pctx, shader_state);
break;
case MESA_SHADER_MESH:
state->pctx->bind_ms_state(state->pctx, shader_state);
break;
case MESA_SHADER_FRAGMENT:
state->pctx->bind_fs_state(state->pctx, shader_state);
state->noop_fs_bound = false;
@@ -740,6 +746,24 @@ handle_graphics_stages(struct rendering_state *state, VkShaderStageFlagBits shad
if (!dynamic_tess_origin)
state->tess_ccw = false;
break;
case VK_SHADER_STAGE_TASK_BIT_EXT:
state->inlines_dirty[MESA_SHADER_TASK] = state->shaders[MESA_SHADER_TASK]->inlines.can_inline;
state->dispatch_info.block[0] = state->shaders[MESA_SHADER_TASK]->pipeline_nir->nir->info.workgroup_size[0];
state->dispatch_info.block[1] = state->shaders[MESA_SHADER_TASK]->pipeline_nir->nir->info.workgroup_size[1];
state->dispatch_info.block[2] = state->shaders[MESA_SHADER_TASK]->pipeline_nir->nir->info.workgroup_size[2];
if (!state->shaders[MESA_SHADER_TASK]->inlines.can_inline)
state->pctx->bind_ts_state(state->pctx, state->shaders[MESA_SHADER_TASK]->shader_cso);
break;
case VK_SHADER_STAGE_MESH_BIT_EXT:
state->inlines_dirty[MESA_SHADER_MESH] = state->shaders[MESA_SHADER_MESH]->inlines.can_inline;
if (!(shader_stages & VK_SHADER_STAGE_TASK_BIT_EXT)) {
state->dispatch_info.block[0] = state->shaders[MESA_SHADER_MESH]->pipeline_nir->nir->info.workgroup_size[0];
state->dispatch_info.block[1] = state->shaders[MESA_SHADER_MESH]->pipeline_nir->nir->info.workgroup_size[1];
state->dispatch_info.block[2] = state->shaders[MESA_SHADER_MESH]->pipeline_nir->nir->info.workgroup_size[2];
}
if (!state->shaders[MESA_SHADER_MESH]->inlines.can_inline)
state->pctx->bind_ms_state(state->pctx, state->shaders[MESA_SHADER_MESH]->shader_cso);
break;
default:
assert(0);
break;
@@ -778,6 +802,14 @@ unbind_graphics_stages(struct rendering_state *state, VkShaderStageFlagBits shad
if (state->shaders[MESA_SHADER_VERTEX])
state->pctx->bind_vs_state(state->pctx, NULL);
break;
case MESA_SHADER_TASK:
if (state->shaders[MESA_SHADER_TASK])
state->pctx->bind_ts_state(state->pctx, NULL);
break;
case MESA_SHADER_MESH:
if (state->shaders[MESA_SHADER_MESH])
state->pctx->bind_ms_state(state->pctx, NULL);
break;
default:
unreachable("what stage is this?!");
}
@@ -805,7 +837,11 @@ static void handle_graphics_pipeline(struct vk_cmd_queue_entry *cmd,
const struct vk_graphics_pipeline_state *ps = &pipeline->graphics_state;
lvp_pipeline_shaders_compile(pipeline);
bool dynamic_tess_origin = BITSET_TEST(ps->dynamic, MESA_VK_DYNAMIC_TS_DOMAIN_ORIGIN);
unbind_graphics_stages(state, (~pipeline->graphics_state.shader_stages) & VK_SHADER_STAGE_ALL_GRAPHICS);
unbind_graphics_stages(state,
(~pipeline->graphics_state.shader_stages) &
(VK_SHADER_STAGE_ALL_GRAPHICS |
VK_SHADER_STAGE_TASK_BIT_EXT |
VK_SHADER_STAGE_MESH_BIT_EXT));
lvp_forall_gfx_stage(sh) {
if (pipeline->graphics_state.shader_stages & mesa_to_vk_shader_stage(sh))
state->shaders[sh] = &pipeline->shaders[sh];
@@ -1443,6 +1479,12 @@ static void handle_descriptor_sets(struct vk_cmd_queue_entry *cmd,
if (set->layout->shader_stages & VK_SHADER_STAGE_FRAGMENT_BIT)
handle_set_stage(state, &dyn_info, set, MESA_SHADER_FRAGMENT, MESA_SHADER_FRAGMENT);
if (set->layout->shader_stages & VK_SHADER_STAGE_TASK_BIT_EXT)
handle_set_stage(state, &dyn_info, set, MESA_SHADER_TASK, MESA_SHADER_TASK);
if (set->layout->shader_stages & VK_SHADER_STAGE_MESH_BIT_EXT)
handle_set_stage(state, &dyn_info, set, MESA_SHADER_MESH, MESA_SHADER_MESH);
increment_dyn_info(&dyn_info, layout->vk.set_layouts[bds->first_set + i], true);
}
}
@@ -2842,12 +2884,16 @@ static void handle_push_constants(struct vk_cmd_queue_entry *cmd,
state->pcbuf_dirty[MESA_SHADER_TESS_CTRL] |= (stage_flags & VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT) > 0;
state->pcbuf_dirty[MESA_SHADER_TESS_EVAL] |= (stage_flags & VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT) > 0;
state->pcbuf_dirty[MESA_SHADER_COMPUTE] |= (stage_flags & VK_SHADER_STAGE_COMPUTE_BIT) > 0;
state->pcbuf_dirty[MESA_SHADER_TASK] |= (stage_flags & VK_SHADER_STAGE_TASK_BIT_EXT) > 0;
state->pcbuf_dirty[MESA_SHADER_MESH] |= (stage_flags & VK_SHADER_STAGE_MESH_BIT_EXT) > 0;
state->inlines_dirty[MESA_SHADER_VERTEX] |= (stage_flags & VK_SHADER_STAGE_VERTEX_BIT) > 0;
state->inlines_dirty[MESA_SHADER_FRAGMENT] |= (stage_flags & VK_SHADER_STAGE_FRAGMENT_BIT) > 0;
state->inlines_dirty[MESA_SHADER_GEOMETRY] |= (stage_flags & VK_SHADER_STAGE_GEOMETRY_BIT) > 0;
state->inlines_dirty[MESA_SHADER_TESS_CTRL] |= (stage_flags & VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT) > 0;
state->inlines_dirty[MESA_SHADER_TESS_EVAL] |= (stage_flags & VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT) > 0;
state->inlines_dirty[MESA_SHADER_COMPUTE] |= (stage_flags & VK_SHADER_STAGE_COMPUTE_BIT) > 0;
state->inlines_dirty[MESA_SHADER_TASK] |= (stage_flags & VK_SHADER_STAGE_TASK_BIT_EXT) > 0;
state->inlines_dirty[MESA_SHADER_MESH] |= (stage_flags & VK_SHADER_STAGE_MESH_BIT_EXT) > 0;
}
static void lvp_execute_cmd_buffer(struct lvp_cmd_buffer *cmd_buffer,
@@ -3470,6 +3516,16 @@ static void handle_push_descriptor_set_generic(struct vk_cmd_push_descriptor_set
MESA_SHADER_TESS_EVAL, MESA_SHADER_TESS_EVAL,
j, desc->descriptor_type,
info);
if (layout->shader_stages & VK_SHADER_STAGE_TASK_BIT_EXT)
handle_descriptor(state, &dyn_info, binding,
MESA_SHADER_TASK, MESA_SHADER_TASK,
j, desc->descriptor_type,
info);
if (layout->shader_stages & VK_SHADER_STAGE_MESH_BIT_EXT)
handle_descriptor(state, &dyn_info, binding,
MESA_SHADER_MESH, MESA_SHADER_MESH,
j, desc->descriptor_type,
info);
}
info_idx += desc->descriptor_count;
}

View file

@@ -54,6 +54,8 @@ shader_destroy(struct lvp_device *device, struct lvp_shader *shader)
device->queue.ctx->delete_gs_state,
device->queue.ctx->delete_fs_state,
device->queue.ctx->delete_compute_state,
device->queue.ctx->delete_ts_state,
device->queue.ctx->delete_ms_state,
};
set_foreach(&shader->inlines.variants, entry) {
struct lvp_inline_variant *variant = (void*)entry->key;
@@ -363,7 +365,7 @@ static VkResult
compile_spirv(struct lvp_device *pdevice, const VkPipelineShaderStageCreateInfo *sinfo, nir_shader **nir)
{
gl_shader_stage stage = vk_to_mesa_shader_stage(sinfo->stage);
assert(stage <= MESA_SHADER_COMPUTE && stage != MESA_SHADER_NONE);
assert(stage <= LVP_SHADER_STAGES && stage != MESA_SHADER_NONE);
VkResult result;
const struct spirv_to_nir_options spirv_options = {
@@ -409,6 +411,7 @@ compile_spirv(struct lvp_device *pdevice, const VkPipelineShaderStageCreateInfo
.int8 = true,
.float16 = true,
.demote_to_helper_invocation = true,
.mesh_shading = true,
},
.ubo_addr_format = nir_address_format_32bit_index_offset,
.ssbo_addr_format = nir_address_format_32bit_index_offset,
@@ -439,7 +442,7 @@ static void
lvp_shader_lower(struct lvp_device *pdevice, nir_shader *nir, struct lvp_shader *shader, struct lvp_pipeline_layout *layout)
{
if (nir->info.stage != MESA_SHADER_TESS_CTRL)
NIR_PASS_V(nir, remove_scoped_barriers, nir->info.stage == MESA_SHADER_COMPUTE);
NIR_PASS_V(nir, remove_scoped_barriers, nir->info.stage == MESA_SHADER_COMPUTE || nir->info.stage == MESA_SHADER_MESH || nir->info.stage == MESA_SHADER_TASK);
const struct nir_lower_sysvals_to_varyings_options sysvals_to_varyings = {
.frag_coord = true,
@@ -485,11 +488,19 @@ lvp_shader_lower(struct lvp_device *pdevice, nir_shader *nir, struct lvp_shader
nir_var_mem_global,
nir_address_format_64bit_global);
if (nir->info.stage == MESA_SHADER_COMPUTE) {
if (nir->info.stage == MESA_SHADER_COMPUTE ||
nir->info.stage == MESA_SHADER_TASK ||
nir->info.stage == MESA_SHADER_MESH) {
NIR_PASS_V(nir, nir_lower_vars_to_explicit_types, nir_var_mem_shared, shared_var_info);
NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_mem_shared, nir_address_format_32bit_offset);
}
if (nir->info.stage == MESA_SHADER_TASK ||
nir->info.stage == MESA_SHADER_MESH) {
NIR_PASS_V(nir, nir_lower_vars_to_explicit_types, nir_var_mem_task_payload, shared_var_info);
NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_mem_task_payload, nir_address_format_32bit_offset);
}
NIR_PASS_V(nir, nir_remove_dead_variables, nir_var_shader_temp, NULL);
if (nir->info.stage == MESA_SHADER_VERTEX ||
@@ -534,7 +545,7 @@ lvp_shader_compile_to_ir(struct lvp_pipeline *pipeline,
{
struct lvp_device *pdevice = pipeline->device;
gl_shader_stage stage = vk_to_mesa_shader_stage(sinfo->stage);
assert(stage <= MESA_SHADER_COMPUTE && stage != MESA_SHADER_NONE);
assert(stage <= LVP_SHADER_STAGES && stage != MESA_SHADER_NONE);
struct lvp_shader *shader = &pipeline->shaders[stage];
nir_shader *nir;
VkResult result = compile_spirv(pdevice, sinfo, &nir);
@@ -623,6 +634,8 @@ lvp_pipeline_xfb_init(struct lvp_pipeline *pipeline)
stage = MESA_SHADER_GEOMETRY;
else if (pipeline->shaders[MESA_SHADER_TESS_EVAL].pipeline_nir)
stage = MESA_SHADER_TESS_EVAL;
else if (pipeline->shaders[MESA_SHADER_MESH].pipeline_nir)
stage = MESA_SHADER_MESH;
pipeline->last_vertex = stage;
lvp_shader_xfb_init(&pipeline->shaders[stage]);
}
@@ -653,6 +666,10 @@ lvp_shader_compile_stage(struct lvp_device *device, struct lvp_shader *shader, n
return device->queue.ctx->create_tcs_state(device->queue.ctx, &shstate);
case MESA_SHADER_TESS_EVAL:
return device->queue.ctx->create_tes_state(device->queue.ctx, &shstate);
case MESA_SHADER_TASK:
return device->queue.ctx->create_ts_state(device->queue.ctx, &shstate);
case MESA_SHADER_MESH:
return device->queue.ctx->create_ms_state(device->queue.ctx, &shstate);
default:
unreachable("illegal shader");
break;
@@ -834,6 +851,7 @@ lvp_graphics_pipeline_init(struct lvp_pipeline *pipeline,
pipeline->disable_multisample = p->disable_multisample;
pipeline->line_rectangular = p->line_rectangular;
memcpy(pipeline->shaders, p->shaders, sizeof(struct lvp_shader) * 4);
memcpy(&pipeline->shaders[MESA_SHADER_TASK], &p->shaders[MESA_SHADER_TASK], sizeof(struct lvp_shader) * 2);
lvp_forall_gfx_stage(i) {
if (i == MESA_SHADER_FRAGMENT)
continue;
@@ -1161,7 +1179,7 @@ create_shader_object(struct lvp_device *device, const VkShaderCreateInfoEXT *pCr
{
nir_shader *nir = NULL;
gl_shader_stage stage = vk_to_mesa_shader_stage(pCreateInfo->stage);
assert(stage <= MESA_SHADER_COMPUTE && stage != MESA_SHADER_NONE);
assert(stage <= LVP_SHADER_STAGES && stage != MESA_SHADER_NONE);
if (pCreateInfo->codeType == VK_SHADER_CODE_TYPE_SPIRV_EXT) {
VkShaderModuleCreateInfo minfo = {
VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO,

View file

@@ -113,7 +113,7 @@ void __lvp_finishme(const char *file, int line, const char *format, ...)
return; \
} while (0)
#define LVP_SHADER_STAGES MESA_SHADER_STAGES
#define LVP_SHADER_STAGES (MESA_SHADER_MESH + 1)
#define LVP_STAGE_MASK BITFIELD_MASK(LVP_SHADER_STAGES)
#define LVP_STAGE_MASK_GFX (BITFIELD_MASK(LVP_SHADER_STAGES) & ~BITFIELD_BIT(MESA_SHADER_COMPUTE))