diff --git a/src/gallium/drivers/d3d12/d3d12_compiler.cpp b/src/gallium/drivers/d3d12/d3d12_compiler.cpp index 8d816d037b3..5780005d509 100644 --- a/src/gallium/drivers/d3d12/d3d12_compiler.cpp +++ b/src/gallium/drivers/d3d12/d3d12_compiler.cpp @@ -170,7 +170,7 @@ compile_nir(struct d3d12_context *ctx, struct d3d12_shader_selector *sel, NIR_PASS_V(nir, d3d12_lower_yflip); } NIR_PASS_V(nir, nir_lower_packed_ubo_loads); - NIR_PASS_V(nir, d3d12_lower_load_first_vertex); + NIR_PASS_V(nir, d3d12_lower_load_draw_params); NIR_PASS_V(nir, d3d12_lower_state_vars, shader); NIR_PASS_V(nir, dxil_nir_lower_bool_input); NIR_PASS_V(nir, dxil_nir_lower_loads_stores_to_dxil); diff --git a/src/gallium/drivers/d3d12/d3d12_compiler.h b/src/gallium/drivers/d3d12/d3d12_compiler.h index 703b058bb48..40cc4fba286 100644 --- a/src/gallium/drivers/d3d12/d3d12_compiler.h +++ b/src/gallium/drivers/d3d12/d3d12_compiler.h @@ -43,7 +43,7 @@ extern "C" { enum d3d12_state_var { D3D12_STATE_VAR_Y_FLIP = 0, D3D12_STATE_VAR_PT_SPRITE, - D3D12_STATE_VAR_FIRST_VERTEX, + D3D12_STATE_VAR_DRAW_PARAMS, D3D12_STATE_VAR_DEPTH_TRANSFORM, D3D12_MAX_GRAPHICS_STATE_VARS, diff --git a/src/gallium/drivers/d3d12/d3d12_draw.cpp b/src/gallium/drivers/d3d12/d3d12_draw.cpp index 7bd739f9c72..72b263e4ace 100644 --- a/src/gallium/drivers/d3d12/d3d12_draw.cpp +++ b/src/gallium/drivers/d3d12/d3d12_draw.cpp @@ -340,6 +340,7 @@ fill_image_descriptors(struct d3d12_context *ctx, static unsigned fill_graphics_state_vars(struct d3d12_context *ctx, const struct pipe_draw_info *dinfo, + unsigned drawid, const struct pipe_draw_start_count_bias *draw, struct d3d12_shader *shader, uint32_t *values) @@ -361,8 +362,11 @@ fill_graphics_state_vars(struct d3d12_context *ctx, ptr[3] = fui(D3D12_MAX_POINT_SIZE); size += 4; break; - case D3D12_STATE_VAR_FIRST_VERTEX: + case D3D12_STATE_VAR_DRAW_PARAMS: ptr[0] = dinfo->index_size ? draw->index_bias : draw->start; + ptr[1] = dinfo->start_instance; + ptr[2] = drawid; + ptr[3] = dinfo->index_size ? -1 : 0; size += 4; break; case D3D12_STATE_VAR_DEPTH_TRANSFORM: @@ -500,6 +504,7 @@ update_shader_stage_root_parameters(struct d3d12_context *ctx, static unsigned update_graphics_root_parameters(struct d3d12_context *ctx, const struct pipe_draw_info *dinfo, + unsigned drawid, const struct pipe_draw_start_count_bias *draw, D3D12_GPU_DESCRIPTOR_HANDLE root_desc_tables[MAX_DESCRIPTOR_TABLES], int root_desc_indices[MAX_DESCRIPTOR_TABLES]) @@ -516,7 +521,7 @@ update_graphics_root_parameters(struct d3d12_context *ctx, /* TODO Don't always update state vars */ if (shader_sel->current->num_state_vars > 0) { uint32_t constants[D3D12_MAX_GRAPHICS_STATE_VARS * 4]; - unsigned size = fill_graphics_state_vars(ctx, dinfo, draw, shader_sel->current, constants); + unsigned size = fill_graphics_state_vars(ctx, dinfo, drawid, draw, shader_sel->current, constants); ctx->cmdlist->SetGraphicsRoot32BitConstants(num_params, size, constants, 0); num_params++; } @@ -847,7 +852,8 @@ d3d12_draw_vbo(struct pipe_context *pctx, D3D12_GPU_DESCRIPTOR_HANDLE root_desc_tables[MAX_DESCRIPTOR_TABLES]; int root_desc_indices[MAX_DESCRIPTOR_TABLES]; - unsigned num_root_descriptors = update_graphics_root_parameters(ctx, dinfo, &draws[0], root_desc_tables, root_desc_indices); + unsigned num_root_descriptors = update_graphics_root_parameters(ctx, dinfo, drawid_offset, &draws[0], + root_desc_tables, root_desc_indices); bool need_zero_one_depth_range = d3d12_need_zero_one_depth_range(ctx); if (need_zero_one_depth_range != ctx->need_zero_one_depth_range) { diff --git a/src/gallium/drivers/d3d12/d3d12_nir_passes.c b/src/gallium/drivers/d3d12/d3d12_nir_passes.c index 85a79d14575..d0927f67c88 100644 --- a/src/gallium/drivers/d3d12/d3d12_nir_passes.c +++ b/src/gallium/drivers/d3d12/d3d12_nir_passes.c @@ -318,51 +318,41 @@ d3d12_lower_uint_cast(nir_shader *nir, bool is_signed) } static bool -lower_load_first_vertex(nir_builder *b, nir_instr *instr, nir_variable **first_vertex) +lower_load_draw_params(nir_builder *b, nir_instr *instr, void *draw_params) { if (instr->type != nir_instr_type_intrinsic) return false; nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr); - if (intr->intrinsic != nir_intrinsic_load_first_vertex) + if (intr->intrinsic != nir_intrinsic_load_first_vertex && + intr->intrinsic != nir_intrinsic_load_base_instance && + intr->intrinsic != nir_intrinsic_load_draw_id && + intr->intrinsic != nir_intrinsic_load_is_indexed_draw) return false; b->cursor = nir_before_instr(&intr->instr); - nir_ssa_def *load = d3d12_get_state_var(b, D3D12_STATE_VAR_FIRST_VERTEX, "d3d12_FirstVertex", - glsl_uint_type(), first_vertex); - nir_ssa_def_rewrite_uses(&intr->dest.ssa, load); + nir_ssa_def *load = d3d12_get_state_var(b, D3D12_STATE_VAR_DRAW_PARAMS, "d3d12_DrawParams", + glsl_uvec4_type(), draw_params); + unsigned channel = intr->intrinsic == nir_intrinsic_load_first_vertex ? 0 : + intr->intrinsic == nir_intrinsic_load_base_instance ? 1 : + intr->intrinsic == nir_intrinsic_load_draw_id ? 2 : 3; + nir_ssa_def_rewrite_uses(&intr->dest.ssa, nir_channel(b, load, channel)); nir_instr_remove(instr); return true; } bool -d3d12_lower_load_first_vertex(struct nir_shader *nir) +d3d12_lower_load_draw_params(struct nir_shader *nir) { - nir_variable *first_vertex = NULL; - bool progress = false; - + nir_variable *draw_params = NULL; if (nir->info.stage != MESA_SHADER_VERTEX) return false; - nir_foreach_function(function, nir) { - if (function->impl) { - nir_builder b; - nir_builder_init(&b, function->impl); - - nir_foreach_block(block, function->impl) { - nir_foreach_instr_safe(instr, block) { - progress |= lower_load_first_vertex(&b, instr, &first_vertex); - } - } - - nir_metadata_preserve(function->impl, nir_metadata_block_index | - nir_metadata_dominance); - } - } - return progress; + return nir_shader_instructions_pass(nir, lower_load_draw_params, + nir_metadata_block_index | nir_metadata_dominance, &draw_params); } static void diff --git a/src/gallium/drivers/d3d12/d3d12_nir_passes.h b/src/gallium/drivers/d3d12/d3d12_nir_passes.h index 0c81d2e262b..f5aa93c03a6 100644 --- a/src/gallium/drivers/d3d12/d3d12_nir_passes.h +++ b/src/gallium/drivers/d3d12/d3d12_nir_passes.h @@ -62,7 +62,7 @@ void d3d12_lower_depth_range(nir_shader *nir); bool -d3d12_lower_load_first_vertex(nir_shader *nir); +d3d12_lower_load_draw_params(nir_shader *nir); bool d3d12_lower_compute_state_vars(nir_shader *nir);